edit_prediction.rs

   1use anyhow::Result;
   2use client::{Client, EditPredictionUsage, UserStore};
   3use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
   4use cloud_llm_client::predict_edits_v3::{
   5    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   6};
   7use cloud_llm_client::{
   8    EditPredictionRejectReason, EditPredictionRejection,
   9    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  10    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  11};
  12use collections::{HashMap, HashSet};
  13use copilot::{Copilot, Reinstall, SignIn, SignOut};
  14use db::kvp::{Dismissable, KeyValueStore};
  15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::{
  18    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  19    channel::mpsc::{self, UnboundedReceiver},
  20    select_biased,
  21};
  22use gpui::BackgroundExecutor;
  23use gpui::http_client::Url;
  24use gpui::{
  25    App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
  26    http_client::{self, AsyncBody, Method},
  27    prelude::*,
  28};
  29use heapless::Vec as ArrayVec;
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{
  39    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  40};
  41use std::collections::{VecDeque, hash_map};
  42use std::env;
  43use text::{AnchorRangeExt, Edit};
  44use workspace::{AppState, Workspace};
  45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  46
  47use std::mem;
  48use std::ops::Range;
  49use std::path::Path;
  50use std::rc::Rc;
  51use std::str::FromStr as _;
  52use std::sync::Arc;
  53use std::time::{Duration, Instant};
  54
  55use thiserror::Error;
  56use util::{RangeExt as _, ResultExt as _};
  57
  58pub mod cursor_excerpt;
  59pub mod example_spec;
  60pub mod fim;
  61mod license_detection;
  62pub mod mercury;
  63pub mod ollama;
  64mod onboarding_modal;
  65pub mod open_ai_response;
  66mod prediction;
  67
  68pub mod udiff;
  69
  70mod capture_example;
  71pub mod open_ai_compatible;
  72mod zed_edit_prediction_delegate;
  73pub mod zeta;
  74
  75#[cfg(test)]
  76mod edit_prediction_tests;
  77
  78use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
  79use crate::example_spec::ExampleSpec;
  80use crate::license_detection::LicenseDetectionWatcher;
  81use crate::mercury::Mercury;
  82use crate::onboarding_modal::ZedPredictModal;
  83pub use crate::prediction::EditPrediction;
  84pub use crate::prediction::EditPredictionId;
  85use crate::prediction::EditPredictionResult;
  86pub use capture_example::capture_example;
  87pub use language_model::ApiKeyState;
  88pub use telemetry_events::EditPredictionRating;
  89pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  90
// Actions registered under the `edit_prediction` namespace.
// NOTE(review): the `///` lines inside the macro are likely surfaced by
// `actions!` as user-visible action documentation — confirm before rewording.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
 100
 101/// Maximum number of events to track.
 102const EVENT_COUNT_MAX: usize = 10;
 103const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 104const EDIT_HISTORY_DIFF_SIZE_LIMIT: usize = 2048 * 3; // ~2048 tokens or ~50% of typical prompt budget
 105const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
 106const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
 107const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 108const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 109const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
 110const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
 111const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
 112
 113pub struct EditPredictionJumpsFeatureFlag;
 114
 115impl FeatureFlag for EditPredictionJumpsFeatureFlag {
 116    const NAME: &'static str = "edit_prediction_jumps";
 117}
 118
 119#[derive(Clone)]
 120struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
 121
 122impl Global for EditPredictionStoreGlobal {}
 123
 124/// Configuration for using the raw Zeta2 endpoint.
 125/// When set, the client uses the raw endpoint and constructs the prompt itself.
 126/// The version is also used as the Baseten environment name (lowercased).
 127#[derive(Clone)]
 128pub struct Zeta2RawConfig {
 129    pub model_id: Option<String>,
 130    pub environment: Option<String>,
 131    pub format: ZetaFormat,
 132}
 133
 134pub struct EditPredictionStore {
 135    client: Arc<Client>,
 136    user_store: Entity<UserStore>,
 137    llm_token: LlmApiToken,
 138    _fetch_experiments_task: Task<()>,
 139    projects: HashMap<EntityId, ProjectState>,
 140    update_required: bool,
 141    edit_prediction_model: EditPredictionModel,
 142    zeta2_raw_config: Option<Zeta2RawConfig>,
 143    preferred_experiment: Option<String>,
 144    available_experiments: Vec<String>,
 145    pub mercury: Mercury,
 146    data_collection_choice: DataCollectionChoice,
 147    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
 148    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
 149    shown_predictions: VecDeque<EditPrediction>,
 150    rated_predictions: HashSet<EditPredictionId>,
 151    #[cfg(test)]
 152    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
 153}
 154
 155pub(crate) struct EditPredictionRejectionPayload {
 156    rejection: EditPredictionRejection,
 157    organization_id: Option<OrganizationId>,
 158}
 159
 160#[derive(Copy, Clone, PartialEq, Eq)]
 161pub enum EditPredictionModel {
 162    Zeta,
 163    Fim { format: EditPredictionPromptFormat },
 164    Mercury,
 165}
 166
 167#[derive(Clone)]
 168pub struct EditPredictionModelInput {
 169    project: Entity<Project>,
 170    buffer: Entity<Buffer>,
 171    snapshot: BufferSnapshot,
 172    position: Anchor,
 173    events: Vec<Arc<zeta_prompt::Event>>,
 174    related_files: Vec<RelatedFile>,
 175    trigger: PredictEditsRequestTrigger,
 176    diagnostic_search_range: Range<Point>,
 177    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 178    can_collect_data: bool,
 179    is_open_source: bool,
 180}
 181
 182#[derive(Debug)]
 183pub enum DebugEvent {
 184    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
 185    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
 186    EditPredictionStarted(EditPredictionStartedDebugEvent),
 187    EditPredictionFinished(EditPredictionFinishedDebugEvent),
 188}
 189
 190#[derive(Debug)]
 191pub struct ContextRetrievalStartedDebugEvent {
 192    pub project_entity_id: EntityId,
 193    pub timestamp: Instant,
 194    pub search_prompt: String,
 195}
 196
 197#[derive(Debug)]
 198pub struct ContextRetrievalFinishedDebugEvent {
 199    pub project_entity_id: EntityId,
 200    pub timestamp: Instant,
 201    pub metadata: Vec<(&'static str, SharedString)>,
 202}
 203
 204#[derive(Debug)]
 205pub struct EditPredictionStartedDebugEvent {
 206    pub buffer: WeakEntity<Buffer>,
 207    pub position: Anchor,
 208    pub prompt: Option<String>,
 209}
 210
 211#[derive(Debug)]
 212pub struct EditPredictionFinishedDebugEvent {
 213    pub buffer: WeakEntity<Buffer>,
 214    pub position: Anchor,
 215    pub model_output: Option<String>,
 216}
 217
 218/// An event with associated metadata for reconstructing buffer state.
 219#[derive(Clone)]
 220pub struct StoredEvent {
 221    pub event: Arc<zeta_prompt::Event>,
 222    pub old_snapshot: TextBufferSnapshot,
 223    pub new_snapshot_version: clock::Global,
 224    pub total_edit_range: Range<Anchor>,
 225}
 226
 227impl StoredEvent {
 228    fn can_merge(
 229        &self,
 230        next_old_event: &StoredEvent,
 231        latest_snapshot: &TextBufferSnapshot,
 232        latest_edit_range: &Range<Anchor>,
 233    ) -> bool {
 234        // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
 235        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 236            return false;
 237        }
 238        if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
 239            return false;
 240        }
 241        if self.new_snapshot_version != next_old_event.old_snapshot.version {
 242            return false;
 243        }
 244        if !latest_snapshot
 245            .version
 246            .observed_all(&next_old_event.new_snapshot_version)
 247        {
 248            return false;
 249        }
 250
 251        let a_is_predicted = matches!(
 252            self.event.as_ref(),
 253            zeta_prompt::Event::BufferChange {
 254                predicted: true,
 255                ..
 256            }
 257        );
 258        let b_is_predicted = matches!(
 259            next_old_event.event.as_ref(),
 260            zeta_prompt::Event::BufferChange {
 261                predicted: true,
 262                ..
 263            }
 264        );
 265
 266        // If events come from the same source (both predicted or both manual) then
 267        // we would have coalesced them already.
 268        if a_is_predicted == b_is_predicted {
 269            return false;
 270        }
 271
 272        let left_range = self.total_edit_range.to_point(latest_snapshot);
 273        let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
 274        let latest_range = latest_edit_range.to_point(latest_snapshot);
 275
 276        // Events near to the latest edit are not merged if their sources differ.
 277        if lines_between_ranges(&left_range, &latest_range)
 278            .min(lines_between_ranges(&right_range, &latest_range))
 279            <= CHANGE_GROUPING_LINE_SPAN
 280        {
 281            return false;
 282        }
 283
 284        // Events that are distant from each other are not merged.
 285        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 286            return false;
 287        }
 288
 289        true
 290    }
 291}
 292
 293fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 294    if left.start > right.end {
 295        return left.start.row - right.end.row;
 296    }
 297    if right.start > left.end {
 298        return right.start.row - left.end.row;
 299    }
 300    0
 301}
 302
 303struct ProjectState {
 304    events: VecDeque<StoredEvent>,
 305    last_event: Option<LastEvent>,
 306    recent_paths: VecDeque<ProjectPath>,
 307    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 308    current_prediction: Option<CurrentEditPrediction>,
 309    next_pending_prediction_id: usize,
 310    pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
 311    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 312    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
 313    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
 314    cancelled_predictions: HashSet<usize>,
 315    context: Entity<RelatedExcerptStore>,
 316    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 317    _subscriptions: [gpui::Subscription; 2],
 318    copilot: Option<Entity<Copilot>>,
 319}
 320
 321impl ProjectState {
 322    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 323        self.events
 324            .iter()
 325            .cloned()
 326            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 327                let (one, two) = event.split_by_pause();
 328                let one = one.finalize(&self.license_detection_watchers, cx);
 329                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 330                one.into_iter().chain(two)
 331            }))
 332            .collect()
 333    }
 334
 335    fn cancel_pending_prediction(
 336        &mut self,
 337        pending_prediction: PendingPrediction,
 338        cx: &mut Context<EditPredictionStore>,
 339    ) {
 340        self.cancelled_predictions.insert(pending_prediction.id);
 341
 342        if pending_prediction.drop_on_cancel {
 343            drop(pending_prediction.task);
 344        } else {
 345            cx.spawn(async move |this, cx| {
 346                let Some(prediction_id) = pending_prediction.task.await else {
 347                    return;
 348                };
 349
 350                this.update(cx, |this, cx| {
 351                    this.reject_prediction(
 352                        prediction_id,
 353                        EditPredictionRejectReason::Canceled,
 354                        false,
 355                        None,
 356                        None,
 357                        cx,
 358                    );
 359                })
 360                .ok();
 361            })
 362            .detach()
 363        }
 364    }
 365
 366    fn active_buffer(
 367        &self,
 368        project: &Entity<Project>,
 369        cx: &App,
 370    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 371        let project = project.read(cx);
 372        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 373        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 374        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 375        Some((active_buffer, registered_buffer.last_position))
 376    }
 377}
 378
 379#[derive(Debug, Clone)]
 380struct CurrentEditPrediction {
 381    pub requested_by: PredictionRequestedBy,
 382    pub prediction: EditPrediction,
 383    pub was_shown: bool,
 384    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
 385    pub e2e_latency: std::time::Duration,
 386}
 387
 388impl CurrentEditPrediction {
 389    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 390        let Some(new_edits) = self
 391            .prediction
 392            .interpolate(&self.prediction.buffer.read(cx))
 393        else {
 394            return false;
 395        };
 396
 397        if self.prediction.buffer != old_prediction.prediction.buffer {
 398            return true;
 399        }
 400
 401        let Some(old_edits) = old_prediction
 402            .prediction
 403            .interpolate(&old_prediction.prediction.buffer.read(cx))
 404        else {
 405            return true;
 406        };
 407
 408        let requested_by_buffer_id = self.requested_by.buffer_id();
 409
 410        // This reduces the occurrence of UI thrash from replacing edits
 411        //
 412        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 413        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 414            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 415            && old_edits.len() == 1
 416            && new_edits.len() == 1
 417        {
 418            let (old_range, old_text) = &old_edits[0];
 419            let (new_range, new_text) = &new_edits[0];
 420            new_range == old_range && new_text.starts_with(old_text.as_ref())
 421        } else {
 422            true
 423        }
 424    }
 425}
 426
 427#[derive(Debug, Clone)]
 428enum PredictionRequestedBy {
 429    DiagnosticsUpdate,
 430    Buffer(EntityId),
 431}
 432
 433impl PredictionRequestedBy {
 434    pub fn buffer_id(&self) -> Option<EntityId> {
 435        match self {
 436            PredictionRequestedBy::DiagnosticsUpdate => None,
 437            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 438        }
 439    }
 440}
 441
 442const DIAGNOSTIC_LINES_RANGE: u32 = 20;
 443
 444#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 445pub enum DiagnosticSearchScope {
 446    Local,
 447    Global,
 448}
 449
 450#[derive(Debug)]
 451struct PendingPrediction {
 452    id: usize,
 453    task: Task<Option<EditPredictionId>>,
 454    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
 455    /// If false, the task is awaited to completion so rejection can be reported.
 456    drop_on_cancel: bool,
 457}
 458
 459/// A prediction from the perspective of a buffer.
 460#[derive(Debug)]
 461enum BufferEditPrediction<'a> {
 462    Local { prediction: &'a EditPrediction },
 463    Jump { prediction: &'a EditPrediction },
 464}
 465
 466#[cfg(test)]
 467impl std::ops::Deref for BufferEditPrediction<'_> {
 468    type Target = EditPrediction;
 469
 470    fn deref(&self) -> &Self::Target {
 471        match self {
 472            BufferEditPrediction::Local { prediction } => prediction,
 473            BufferEditPrediction::Jump { prediction } => prediction,
 474        }
 475    }
 476}
 477
 478#[derive(Clone)]
 479
 480struct PendingSettledPrediction {
 481    request_id: EditPredictionId,
 482    editable_anchor_range: Range<Anchor>,
 483    example: Option<ExampleSpec>,
 484    enqueued_at: Instant,
 485    last_edit_at: Instant,
 486    e2e_latency: std::time::Duration,
 487}
 488
 489struct RegisteredBuffer {
 490    file: Option<Arc<dyn File>>,
 491    snapshot: TextBufferSnapshot,
 492    pending_predictions: Vec<PendingSettledPrediction>,
 493    last_position: Option<Anchor>,
 494    _subscriptions: [gpui::Subscription; 2],
 495}
 496
 497#[derive(Clone)]
 498struct LastEvent {
 499    old_snapshot: TextBufferSnapshot,
 500    new_snapshot: TextBufferSnapshot,
 501    old_file: Option<Arc<dyn File>>,
 502    new_file: Option<Arc<dyn File>>,
 503    latest_edit_range: Range<Anchor>,
 504    total_edit_range: Range<Anchor>,
 505    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
 506    predicted: bool,
 507    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
 508    last_edit_time: Option<Instant>,
 509}
 510
 511impl LastEvent {
 512    pub fn finalize(
 513        &self,
 514        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 515        cx: &App,
 516    ) -> Option<StoredEvent> {
 517        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 518        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 519
 520        let in_open_source_repo =
 521            [self.new_file.as_ref(), self.old_file.as_ref()]
 522                .iter()
 523                .all(|file| {
 524                    file.is_some_and(|file| {
 525                        license_detection_watchers
 526                            .get(&file.worktree_id(cx))
 527                            .is_some_and(|watcher| watcher.is_project_open_source())
 528                    })
 529                });
 530
 531        let (diff, edit_range) = compute_diff_between_snapshots_in_range(
 532            &self.old_snapshot,
 533            &self.new_snapshot,
 534            &self.total_edit_range,
 535        )?;
 536
 537        if path == old_path && diff.is_empty() {
 538            None
 539        } else {
 540            Some(StoredEvent {
 541                event: Arc::new(zeta_prompt::Event::BufferChange {
 542                    old_path,
 543                    path,
 544                    diff,
 545                    in_open_source_repo,
 546                    predicted: self.predicted,
 547                }),
 548                old_snapshot: self.old_snapshot.clone(),
 549                new_snapshot_version: self.new_snapshot.version.clone(),
 550                total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
 551                    ..self.new_snapshot.anchor_before(edit_range.end),
 552            })
 553        }
 554    }
 555
 556    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 557        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 558            return (self.clone(), None);
 559        };
 560
 561        let total_edit_range_before_pause = self
 562            .total_edit_range_at_last_pause_boundary
 563            .clone()
 564            .unwrap_or_else(|| self.total_edit_range.clone());
 565
 566        let Some(total_edit_range_after_pause) =
 567            compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
 568        else {
 569            return (self.clone(), None);
 570        };
 571
 572        let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
 573        let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
 574
 575        let before = LastEvent {
 576            old_snapshot: self.old_snapshot.clone(),
 577            new_snapshot: boundary_snapshot.clone(),
 578            old_file: self.old_file.clone(),
 579            new_file: self.new_file.clone(),
 580            latest_edit_range: latest_edit_range_before_pause,
 581            total_edit_range: total_edit_range_before_pause,
 582            total_edit_range_at_last_pause_boundary: None,
 583            predicted: self.predicted,
 584            snapshot_after_last_editing_pause: None,
 585            last_edit_time: self.last_edit_time,
 586        };
 587
 588        let after = LastEvent {
 589            old_snapshot: boundary_snapshot.clone(),
 590            new_snapshot: self.new_snapshot.clone(),
 591            old_file: self.old_file.clone(),
 592            new_file: self.new_file.clone(),
 593            latest_edit_range: latest_edit_range_after_pause,
 594            total_edit_range: total_edit_range_after_pause,
 595            total_edit_range_at_last_pause_boundary: None,
 596            predicted: self.predicted,
 597            snapshot_after_last_editing_pause: None,
 598            last_edit_time: self.last_edit_time,
 599        };
 600
 601        (before, Some(after))
 602    }
 603}
 604
 605fn compute_total_edit_range_between_snapshots(
 606    old_snapshot: &TextBufferSnapshot,
 607    new_snapshot: &TextBufferSnapshot,
 608) -> Option<Range<Anchor>> {
 609    let edits: Vec<Edit<usize>> = new_snapshot
 610        .edits_since::<usize>(&old_snapshot.version)
 611        .collect();
 612
 613    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 614    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 615    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 616
 617    Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
 618}
 619
 620fn compute_old_range_for_new_range(
 621    old_snapshot: &TextBufferSnapshot,
 622    new_snapshot: &TextBufferSnapshot,
 623    total_edit_range: &Range<Anchor>,
 624) -> Option<Range<Point>> {
 625    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
 626    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);
 627
 628    let edits: Vec<Edit<usize>> = new_snapshot
 629        .edits_since::<usize>(&old_snapshot.version)
 630        .collect();
 631    let mut old_start_offset = None;
 632    let mut old_end_offset = None;
 633    let mut delta: isize = 0;
 634
 635    for edit in &edits {
 636        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
 637            old_start_offset = Some(if new_start_offset < edit.new.start {
 638                new_start_offset.checked_add_signed(-delta)?
 639            } else {
 640                edit.old.start
 641            });
 642        }
 643
 644        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
 645            old_end_offset = Some(if new_end_offset < edit.new.start {
 646                new_end_offset.checked_add_signed(-delta)?
 647            } else {
 648                edit.old.end
 649            });
 650        }
 651
 652        delta += edit.new.len() as isize - edit.old.len() as isize;
 653    }
 654
 655    let old_start_offset =
 656        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
 657    let old_end_offset =
 658        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));
 659
 660    Some(
 661        old_snapshot.offset_to_point(old_start_offset)
 662            ..old_snapshot.offset_to_point(old_end_offset),
 663    )
 664}
 665
 666fn compute_diff_between_snapshots_in_range(
 667    old_snapshot: &TextBufferSnapshot,
 668    new_snapshot: &TextBufferSnapshot,
 669    total_edit_range: &Range<Anchor>,
 670) -> Option<(String, Range<Point>)> {
 671    let new_start_point = total_edit_range.start.to_point(new_snapshot);
 672    let new_end_point = total_edit_range.end.to_point(new_snapshot);
 673    let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
 674    let old_start_point = old_range.start;
 675    let old_end_point = old_range.end;
 676
 677    const CONTEXT_LINES: u32 = 3;
 678
 679    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 680    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 681    let old_context_end_row =
 682        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 683    let new_context_end_row =
 684        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 685
 686    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 687    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 688    let old_end_line_offset = old_snapshot
 689        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 690    let new_end_line_offset = new_snapshot
 691        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 692    let old_edit_range = old_start_line_offset..old_end_line_offset;
 693    let new_edit_range = new_start_line_offset..new_end_line_offset;
 694
 695    if new_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
 696        || old_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
 697    {
 698        return None;
 699    }
 700
 701    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 702    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 703
 704    let diff = language::unified_diff_with_offsets(
 705        &old_region_text,
 706        &new_region_text,
 707        old_context_start_row,
 708        new_context_start_row,
 709    );
 710
 711    Some((diff, new_start_point..new_end_point))
 712}
 713
 714fn buffer_path_with_id_fallback(
 715    file: Option<&Arc<dyn File>>,
 716    snapshot: &TextBufferSnapshot,
 717    cx: &App,
 718) -> Arc<Path> {
 719    if let Some(file) = file {
 720        file.full_path(cx).into()
 721    } else {
 722        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 723    }
 724}
 725
 726impl EditPredictionStore {
 727    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 728        cx.try_global::<EditPredictionStoreGlobal>()
 729            .map(|global| global.0.clone())
 730    }
 731
 732    pub fn global(
 733        client: &Arc<Client>,
 734        user_store: &Entity<UserStore>,
 735        cx: &mut App,
 736    ) -> Entity<Self> {
 737        cx.try_global::<EditPredictionStoreGlobal>()
 738            .map(|global| global.0.clone())
 739            .unwrap_or_else(|| {
 740                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 741                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 742                ep_store
 743            })
 744    }
 745
 746    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 747        let data_collection_choice = Self::load_data_collection_choice(cx);
 748
 749        let llm_token = LlmApiToken::global(cx);
 750
 751        let (reject_tx, reject_rx) = mpsc::unbounded();
 752        cx.background_spawn({
 753            let client = client.clone();
 754            let llm_token = llm_token.clone();
 755            let app_version = AppVersion::global(cx);
 756            let background_executor = cx.background_executor().clone();
 757            async move {
 758                Self::handle_rejected_predictions(
 759                    reject_rx,
 760                    client,
 761                    llm_token,
 762                    app_version,
 763                    background_executor,
 764                )
 765                .await
 766            }
 767        })
 768        .detach();
 769
 770        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
 771        cx.spawn(async move |this, cx| {
 772            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
 773        })
 774        .detach();
 775
 776        let mut current_user = user_store.read(cx).watch_current_user();
 777        let fetch_experiments_task = cx.spawn(async move |this, cx| {
 778            while current_user.borrow().is_none() {
 779                current_user.next().await;
 780            }
 781
 782            this.update(cx, |this, cx| {
 783                if cx.is_staff() {
 784                    this.refresh_available_experiments(cx);
 785                }
 786            })
 787            .log_err();
 788        });
 789
 790        let this = Self {
 791            projects: HashMap::default(),
 792            client,
 793            user_store,
 794            llm_token,
 795            _fetch_experiments_task: fetch_experiments_task,
 796            update_required: false,
 797            edit_prediction_model: EditPredictionModel::Zeta,
 798            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 799            preferred_experiment: None,
 800            available_experiments: Vec::new(),
 801            mercury: Mercury::new(cx),
 802
 803            data_collection_choice,
 804            reject_predictions_tx: reject_tx,
 805            settled_predictions_tx,
 806            rated_predictions: Default::default(),
 807            shown_predictions: Default::default(),
 808            #[cfg(test)]
 809            settled_event_callback: None,
 810        };
 811
 812        this
 813    }
 814
 815    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 816        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 817        let format = ZetaFormat::parse(&version_str).ok()?;
 818        let model_id = env::var("ZED_ZETA_MODEL").ok();
 819        let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
 820        Some(Zeta2RawConfig {
 821            model_id,
 822            environment,
 823            format,
 824        })
 825    }
 826
 827    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 828        self.edit_prediction_model = model;
 829    }
 830
 831    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
 832        self.zeta2_raw_config = Some(config);
 833    }
 834
 835    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
 836        self.zeta2_raw_config.as_ref()
 837    }
 838
 839    pub fn preferred_experiment(&self) -> Option<&str> {
 840        self.preferred_experiment.as_deref()
 841    }
 842
 843    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
 844        self.preferred_experiment = experiment;
 845    }
 846
 847    pub fn available_experiments(&self) -> &[String] {
 848        &self.available_experiments
 849    }
 850
 851    pub fn active_experiment(&self) -> Option<&str> {
 852        self.preferred_experiment.as_deref().or_else(|| {
 853            self.shown_predictions
 854                .iter()
 855                .find_map(|p| p.model_version.as_ref())
 856                .and_then(|model_version| model_version.strip_prefix("zeta2:"))
 857        })
 858    }
 859
 860    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
 861        let client = self.client.clone();
 862        let llm_token = self.llm_token.clone();
 863        let app_version = AppVersion::global(cx);
 864        let organization_id = self
 865            .user_store
 866            .read(cx)
 867            .current_organization()
 868            .map(|organization| organization.id.clone());
 869
 870        cx.spawn(async move |this, cx| {
 871            let experiments = cx
 872                .background_spawn(async move {
 873                    let http_client = client.http_client();
 874                    let token = llm_token.acquire(&client, organization_id).await?;
 875                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
 876                    let request = http_client::Request::builder()
 877                        .method(Method::GET)
 878                        .uri(url.as_ref())
 879                        .header("Authorization", format!("Bearer {}", token))
 880                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 881                        .body(Default::default())?;
 882                    let mut response = http_client.send(request).await?;
 883                    if response.status().is_success() {
 884                        let mut body = Vec::new();
 885                        response.body_mut().read_to_end(&mut body).await?;
 886                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
 887                        Ok(experiments)
 888                    } else {
 889                        let mut body = String::new();
 890                        response.body_mut().read_to_string(&mut body).await?;
 891                        anyhow::bail!(
 892                            "Failed to fetch experiments: {:?}\nBody: {}",
 893                            response.status(),
 894                            body
 895                        );
 896                    }
 897                })
 898                .await?;
 899            this.update(cx, |this, cx| {
 900                this.available_experiments = experiments;
 901                cx.notify();
 902            })?;
 903            anyhow::Ok(())
 904        })
 905        .detach_and_log_err(cx);
 906    }
 907
 908    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 909        use ui::IconName;
 910        match self.edit_prediction_model {
 911            EditPredictionModel::Mercury => {
 912                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 913            }
 914            EditPredictionModel::Zeta => {
 915                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 916                    .with_disabled(IconName::ZedPredictDisabled)
 917                    .with_up(IconName::ZedPredictUp)
 918                    .with_down(IconName::ZedPredictDown)
 919                    .with_error(IconName::ZedPredictError)
 920            }
 921            EditPredictionModel::Fim { .. } => {
 922                let settings = &all_language_settings(None, cx).edit_predictions;
 923                match settings.provider {
 924                    EditPredictionProvider::Ollama => {
 925                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 926                    }
 927                    _ => {
 928                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 929                    }
 930                }
 931            }
 932        }
 933    }
 934
 935    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 936        self.mercury.api_token.read(cx).has_key()
 937    }
 938
 939    pub fn mercury_has_payment_required_error(&self) -> bool {
 940        self.mercury.has_payment_required_error()
 941    }
 942
 943    pub fn clear_history(&mut self) {
 944        for project_state in self.projects.values_mut() {
 945            project_state.events.clear();
 946            project_state.last_event.take();
 947        }
 948    }
 949
 950    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 951        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 952            project_state.events.clear();
 953            project_state.last_event.take();
 954        }
 955    }
 956
 957    pub fn edit_history_for_project(
 958        &self,
 959        project: &Entity<Project>,
 960        cx: &App,
 961    ) -> Vec<StoredEvent> {
 962        self.projects
 963            .get(&project.entity_id())
 964            .map(|project_state| project_state.events(cx))
 965            .unwrap_or_default()
 966    }
 967
 968    pub fn context_for_project<'a>(
 969        &'a self,
 970        project: &Entity<Project>,
 971        cx: &'a mut App,
 972    ) -> Vec<RelatedFile> {
 973        self.projects
 974            .get(&project.entity_id())
 975            .map(|project_state| {
 976                project_state.context.update(cx, |context, cx| {
 977                    context
 978                        .related_files_with_buffers(cx)
 979                        .map(|(mut related_file, buffer)| {
 980                            related_file.in_open_source_repo = buffer
 981                                .read(cx)
 982                                .file()
 983                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
 984                            related_file
 985                        })
 986                        .collect()
 987                })
 988            })
 989            .unwrap_or_default()
 990    }
 991
 992    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 993        self.projects
 994            .get(&project.entity_id())
 995            .and_then(|project| project.copilot.clone())
 996    }
 997
 998    pub fn start_copilot_for_project(
 999        &mut self,
1000        project: &Entity<Project>,
1001        cx: &mut Context<Self>,
1002    ) -> Option<Entity<Copilot>> {
1003        if DisableAiSettings::get(None, cx).disable_ai {
1004            return None;
1005        }
1006        let state = self.get_or_init_project(project, cx);
1007
1008        if state.copilot.is_some() {
1009            return state.copilot.clone();
1010        }
1011        let _project = project.clone();
1012        let project = project.read(cx);
1013
1014        let node = project.node_runtime().cloned();
1015        if let Some(node) = node {
1016            let next_id = project.languages().next_language_server_id();
1017            let fs = project.fs().clone();
1018
1019            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1020            state.copilot = Some(copilot.clone());
1021            Some(copilot)
1022        } else {
1023            None
1024        }
1025    }
1026
1027    pub fn context_for_project_with_buffers<'a>(
1028        &'a self,
1029        project: &Entity<Project>,
1030        cx: &'a mut App,
1031    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1032        self.projects
1033            .get(&project.entity_id())
1034            .map(|project| {
1035                project.context.update(cx, |context, cx| {
1036                    context.related_files_with_buffers(cx).collect()
1037                })
1038            })
1039            .unwrap_or_default()
1040    }
1041
1042    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1043        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1044            self.user_store.read(cx).edit_prediction_usage()
1045        } else {
1046            None
1047        }
1048    }
1049
1050    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1051        self.get_or_init_project(project, cx);
1052    }
1053
1054    pub fn register_buffer(
1055        &mut self,
1056        buffer: &Entity<Buffer>,
1057        project: &Entity<Project>,
1058        cx: &mut Context<Self>,
1059    ) {
1060        let project_state = self.get_or_init_project(project, cx);
1061        Self::register_buffer_impl(project_state, buffer, project, cx);
1062    }
1063
    /// Returns the mutable state for `project`, creating it on first use.
    ///
    /// When the state is created:
    /// * a `RelatedExcerptStore` is created for the project, and its events are
    ///   forwarded to `handle_excerpt_store_event` with this project's entity id;
    /// * project events are routed to `handle_project_event`;
    /// * when the project entity is released, its state is removed and the store
    ///   notifies observers.
    ///
    /// All other fields start empty. Later calls return the existing state
    /// unchanged.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Detached rather than stored in `_subscriptions`; events for a
                    // project whose state was removed are ignored by the handler's lookup.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                // Kept on the state, presumably so they are dropped (and unsubscribed)
                // together with it — NOTE(review): confirm gpui drop semantics.
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1103
1104    pub fn remove_project(&mut self, project: &Entity<Project>) {
1105        self.projects.remove(&project.entity_id());
1106    }
1107
1108    fn handle_excerpt_store_event(
1109        &mut self,
1110        project_entity_id: EntityId,
1111        event: &RelatedExcerptStoreEvent,
1112    ) {
1113        if let Some(project_state) = self.projects.get(&project_entity_id) {
1114            if let Some(debug_tx) = project_state.debug_tx.clone() {
1115                match event {
1116                    RelatedExcerptStoreEvent::StartedRefresh => {
1117                        debug_tx
1118                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
1119                                ContextRetrievalStartedDebugEvent {
1120                                    project_entity_id: project_entity_id,
1121                                    timestamp: Instant::now(),
1122                                    search_prompt: String::new(),
1123                                },
1124                            ))
1125                            .ok();
1126                    }
1127                    RelatedExcerptStoreEvent::FinishedRefresh {
1128                        cache_hit_count,
1129                        cache_miss_count,
1130                        mean_definition_latency,
1131                        max_definition_latency,
1132                    } => {
1133                        debug_tx
1134                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
1135                                ContextRetrievalFinishedDebugEvent {
1136                                    project_entity_id: project_entity_id,
1137                                    timestamp: Instant::now(),
1138                                    metadata: vec![
1139                                        (
1140                                            "Cache Hits",
1141                                            format!(
1142                                                "{}/{}",
1143                                                cache_hit_count,
1144                                                cache_hit_count + cache_miss_count
1145                                            )
1146                                            .into(),
1147                                        ),
1148                                        (
1149                                            "Max LSP Time",
1150                                            format!("{} ms", max_definition_latency.as_millis())
1151                                                .into(),
1152                                        ),
1153                                        (
1154                                            "Mean LSP Time",
1155                                            format!("{} ms", mean_definition_latency.as_millis())
1156                                                .into(),
1157                                        ),
1158                                    ],
1159                                },
1160                            ))
1161                            .ok();
1162                    }
1163                }
1164            }
1165        }
1166    }
1167
1168    pub fn debug_info(
1169        &mut self,
1170        project: &Entity<Project>,
1171        cx: &mut Context<Self>,
1172    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1173        let project_state = self.get_or_init_project(project, cx);
1174        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1175        project_state.debug_tx = Some(debug_watch_tx);
1176        debug_watch_rx
1177    }
1178
1179    fn handle_project_event(
1180        &mut self,
1181        project: Entity<Project>,
1182        event: &project::Event,
1183        cx: &mut Context<Self>,
1184    ) {
1185        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1186            return;
1187        }
1188        // TODO [zeta2] init with recent paths
1189        match event {
1190            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1191                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1192                    return;
1193                };
1194                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1195                if let Some(path) = path {
1196                    if let Some(ix) = project_state
1197                        .recent_paths
1198                        .iter()
1199                        .position(|probe| probe == &path)
1200                    {
1201                        project_state.recent_paths.remove(ix);
1202                    }
1203                    project_state.recent_paths.push_front(path);
1204                }
1205            }
1206            project::Event::DiagnosticsUpdated { .. } => {
1207                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1208                    self.refresh_prediction_from_diagnostics(
1209                        project,
1210                        DiagnosticSearchScope::Global,
1211                        cx,
1212                    );
1213                }
1214            }
1215            _ => (),
1216        }
1217    }
1218
    /// Returns the registration for `buffer` within `project_state`, creating it
    /// on first use.
    ///
    /// On every call, a `LicenseDetectionWatcher` is ensured for the buffer's
    /// worktree. There is at most one per worktree, and it is removed from the
    /// project state when the worktree entity is released.
    ///
    /// A new registration captures the buffer's current text snapshot and file. It
    /// forwards `Edited` events to `report_changes_for_buffer`, holding the
    /// project only weakly. It removes itself from the project state when the
    /// buffer entity is released.
    ///
    /// An existing registration is returned as-is; its snapshot is not refreshed.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Only buffers backed by a file in a known worktree get a license watcher.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher when its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Weak handle: the buffer subscription must not keep the project alive.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1287
    /// Records a change to `buffer` in `project`'s edit history.
    ///
    /// Registers the project and buffer if needed. It then compares the buffer's
    /// current text snapshot with the one stored on its registration and returns
    /// early if the version is unchanged. Otherwise the stored file and snapshot
    /// are replaced, and:
    /// * pending predictions whose editable range overlaps the edited range get
    ///   `last_edit_at` reset to now; the settled-predictions worker measures
    ///   quiescence from that timestamp.
    /// * non-local edits are dropped from history unless
    ///   `collaborator_edit_overlaps_locality_region` accepts them.
    /// * edits for which `compute_diff_between_snapshots_in_range` returns `None`
    ///   only finalize the in-progress event into `events`.
    /// * otherwise, the edit is coalesced into the in-progress `last_event` when
    ///   all of these hold:
    ///   - it directly follows that event in the same buffer;
    ///   - it has the same `is_predicted` flag;
    ///   - it lies within `CHANGE_GROUPING_LINE_SPAN` lines of the event's latest
    ///     edit.
    ///
    ///   If any condition fails, the previous event is finalized and a new one is
    ///   started.
    ///
    /// `is_predicted` is `true` when the change comes from accepting a
    /// prediction. `is_local` is `false` for edits that did not originate
    /// locally (e.g. from collaborators).
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the last report.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut edit_range: Option<Range<Anchor>> = None;
        let now = cx.background_executor().now();

        // Union of all edits since the old snapshot. NOTE(review): this assumes
        // `anchored_edits_since` yields edits in buffer order, so the result spans
        // from the first edit's start to the last edit's end — confirm.
        for (_edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Editing inside a pending prediction's editable region restarts its
        // quiescence window.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if !include_in_history {
            return;
        }

        let is_recordable_history_edit =
            compute_diff_between_snapshots_in_range(&old_snapshot, &new_snapshot, &edit_range)
                .is_some();

        let events = &mut project_state.events;

        // A non-recordable edit closes the in-progress event without starting a new one.
        if !is_recordable_history_edit {
            if let Some(event) = project_state.last_event.take() {
                if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                    // Evict the oldest finalized event to keep the queue below EVENT_COUNT_MAX.
                    if events.len() + 1 >= EVENT_COUNT_MAX {
                        events.pop_front();
                    }
                    events.push_back(event);
                }
            }
            return;
        }

        if let Some(last_event) = project_state.last_event.as_mut() {
            // This edit continues the in-progress event only if it starts exactly where
            // that event's snapshot left off in the same buffer.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // If at least LAST_CHANGE_GROUPING_TIME passed since the previous edit,
                // remember the event's state at this pause boundary.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescable: finalize the previous event (same eviction rule as above).
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        // Start a new in-progress event covering just this edit.
        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1420
1421    fn prediction_at(
1422        &mut self,
1423        buffer: &Entity<Buffer>,
1424        position: Option<language::Anchor>,
1425        project: &Entity<Project>,
1426        cx: &App,
1427    ) -> Option<BufferEditPrediction<'_>> {
1428        let project_state = self.projects.get_mut(&project.entity_id())?;
1429        if let Some(position) = position
1430            && let Some(buffer) = project_state
1431                .registered_buffers
1432                .get_mut(&buffer.entity_id())
1433        {
1434            buffer.last_position = Some(position);
1435        }
1436
1437        let CurrentEditPrediction {
1438            requested_by,
1439            prediction,
1440            ..
1441        } = project_state.current_prediction.as_ref()?;
1442
1443        if prediction.targets_buffer(buffer.read(cx)) {
1444            Some(BufferEditPrediction::Local { prediction })
1445        } else {
1446            let show_jump = match requested_by {
1447                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1448                    requested_by_buffer_id == &buffer.entity_id()
1449                }
1450                PredictionRequestedBy::DiagnosticsUpdate => true,
1451            };
1452
1453            if show_jump {
1454                Some(BufferEditPrediction::Jump { prediction })
1455            } else {
1456                None
1457            }
1458        }
1459    }
1460
1461    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1462        let Some(current_prediction) = self
1463            .projects
1464            .get_mut(&project.entity_id())
1465            .and_then(|project_state| project_state.current_prediction.take())
1466        else {
1467            return;
1468        };
1469
1470        self.report_changes_for_buffer(
1471            &current_prediction.prediction.buffer,
1472            project,
1473            true,
1474            true,
1475            cx,
1476        );
1477
1478        // can't hold &mut project_state ref across report_changes_for_buffer_call
1479        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1480            return;
1481        };
1482
1483        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1484            project_state.cancel_pending_prediction(pending_prediction, cx);
1485        }
1486
1487        match self.edit_prediction_model {
1488            EditPredictionModel::Mercury => {
1489                mercury::edit_prediction_accepted(
1490                    current_prediction.prediction.id,
1491                    self.client.http_client(),
1492                    cx,
1493                );
1494            }
1495            EditPredictionModel::Zeta => {
1496                let is_cloud = !matches!(
1497                    all_language_settings(None, cx).edit_predictions.provider,
1498                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1499                );
1500                if is_cloud {
1501                    zeta::edit_prediction_accepted(self, current_prediction, cx)
1502                }
1503            }
1504            EditPredictionModel::Fim { .. } => {}
1505        }
1506    }
1507
1508    async fn handle_rejected_predictions(
1509        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
1510        client: Arc<Client>,
1511        llm_token: LlmApiToken,
1512        app_version: Version,
1513        background_executor: BackgroundExecutor,
1514    ) {
1515        let mut rx = std::pin::pin!(rx.peekable());
1516        let mut batched = Vec::new();
1517
1518        while let Some(EditPredictionRejectionPayload {
1519            rejection,
1520            organization_id,
1521        }) = rx.next().await
1522        {
1523            batched.push(rejection);
1524
1525            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1526                select_biased! {
1527                    next = rx.as_mut().peek().fuse() => {
1528                        if next.is_some() {
1529                            continue;
1530                        }
1531                    }
1532                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1533                }
1534            }
1535
1536            let url = client
1537                .http_client()
1538                .build_zed_llm_url("/predict_edits/reject", &[])
1539                .unwrap();
1540
1541            let flush_count = batched
1542                .len()
1543                // in case items have accumulated after failure
1544                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1545            let start = batched.len() - flush_count;
1546
1547            let body = RejectEditPredictionsBodyRef {
1548                rejections: &batched[start..],
1549            };
1550
1551            let result = Self::send_api_request::<()>(
1552                |builder| {
1553                    let req = builder
1554                        .uri(url.as_ref())
1555                        .body(serde_json::to_string(&body)?.into());
1556                    anyhow::Ok(req?)
1557                },
1558                client.clone(),
1559                llm_token.clone(),
1560                organization_id,
1561                app_version.clone(),
1562                true,
1563            )
1564            .await;
1565
1566            if result.log_err().is_some() {
1567                batched.drain(start..);
1568            }
1569        }
1570    }
1571
    /// Background loop that reports "settled" edit predictions.
    ///
    /// Each message received on `rx` is the enqueue time of a new
    /// `PendingSettledPrediction`. Presumably `rx` is the receiving end of
    /// `settled_predictions_tx`, which `enqueue_settled_prediction` feeds; verify
    /// at the spawn site. After each wake-up the loop walks every registered
    /// buffer of every project and handles each pending prediction as follows:
    ///
    /// - Predictions older than `EDIT_PREDICTION_SETTLED_TTL` are dropped
    ///   without emitting anything.
    /// - Predictions whose `last_edit_at` is at least
    ///   `EDIT_PREDICTION_SETTLED_QUIESCENCE` in the past emit an
    ///   `EDIT_PREDICTION_SETTLED_EVENT` carrying the current text of their
    ///   editable anchor range, and are then removed.
    /// - Everything else is kept. The earliest `last_edit_at` among the kept
    ///   predictions, plus the quiescence period, becomes the next wake-up time.
    ///
    /// The loop exits when `rx` is closed or the store entity has been dropped.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                // A check is scheduled: sleep until its deadline, then sweep below.
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                // Nothing scheduled: block until something is enqueued, then
                // schedule a check one quiescence period after it. Any extra
                // queued notifications are drained. The sweep inspects every
                // pending prediction regardless, and reschedules for the ones
                // that have not settled yet.
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            // Re-read the clock: time has passed while sleeping.
            let now = cx.background_executor().now();

            // Earliest `last_edit_at` among predictions that are still unsettled.
            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                // Expired without ever settling: stop tracking silently.
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                // Quiet long enough: report the region's current text and
                                // stop tracking.
                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                        example = pending_prediction.example.take(),
                                        e2e_latency = pending_prediction.e2e_latency.as_millis(),
                                    );

                                    return false;
                                }

                                // Still being edited: keep it, and track the earliest edit
                                // time so the next wake-up fires when it could first settle.
                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            // `None` means nothing is pending, so the next iteration waits on `rx`.
            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1657
1658    pub(crate) fn enqueue_settled_prediction(
1659        &mut self,
1660        request_id: EditPredictionId,
1661        project: &Entity<Project>,
1662        edited_buffer: &Entity<Buffer>,
1663        edited_buffer_snapshot: &BufferSnapshot,
1664        editable_offset_range: Range<usize>,
1665        example: Option<ExampleSpec>,
1666        e2e_latency: std::time::Duration,
1667        cx: &mut Context<Self>,
1668    ) {
1669        let this = &mut *self;
1670        let project_state = this.get_or_init_project(project, cx);
1671        if let Some(buffer) = project_state
1672            .registered_buffers
1673            .get_mut(&edited_buffer.entity_id())
1674        {
1675            let now = cx.background_executor().now();
1676            buffer.pending_predictions.push(PendingSettledPrediction {
1677                request_id: request_id,
1678                editable_anchor_range: edited_buffer_snapshot
1679                    .anchor_range_around(editable_offset_range),
1680                example,
1681                e2e_latency,
1682                enqueued_at: now,
1683                last_edit_at: now,
1684            });
1685            this.settled_predictions_tx.unbounded_send(now).ok();
1686        }
1687    }
1688
1689    fn reject_current_prediction(
1690        &mut self,
1691        reason: EditPredictionRejectReason,
1692        project: &Entity<Project>,
1693        cx: &App,
1694    ) {
1695        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1696            project_state.pending_predictions.clear();
1697            if let Some(prediction) = project_state.current_prediction.take() {
1698                let model_version = prediction.prediction.model_version.clone();
1699                self.reject_prediction(
1700                    prediction.prediction.id,
1701                    reason,
1702                    prediction.was_shown,
1703                    model_version,
1704                    Some(prediction.e2e_latency),
1705                    cx,
1706                );
1707            }
1708        };
1709    }
1710
1711    fn did_show_current_prediction(
1712        &mut self,
1713        project: &Entity<Project>,
1714        display_type: edit_prediction_types::SuggestionDisplayType,
1715        _cx: &mut Context<Self>,
1716    ) {
1717        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1718            return;
1719        };
1720
1721        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1722            return;
1723        };
1724
1725        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1726        let previous_shown_with = current_prediction.shown_with;
1727
1728        if previous_shown_with.is_none() || !is_jump {
1729            current_prediction.shown_with = Some(display_type);
1730        }
1731
1732        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1733
1734        if is_first_non_jump_show {
1735            current_prediction.was_shown = true;
1736        }
1737
1738        if is_first_non_jump_show {
1739            self.shown_predictions
1740                .push_front(current_prediction.prediction.clone());
1741            if self.shown_predictions.len() > 50 {
1742                let completion = self.shown_predictions.pop_back().unwrap();
1743                self.rated_predictions.remove(&completion.id);
1744            }
1745        }
1746    }
1747
1748    fn reject_prediction(
1749        &mut self,
1750        prediction_id: EditPredictionId,
1751        reason: EditPredictionRejectReason,
1752        was_shown: bool,
1753        model_version: Option<String>,
1754        e2e_latency: Option<std::time::Duration>,
1755        cx: &App,
1756    ) {
1757        match self.edit_prediction_model {
1758            EditPredictionModel::Zeta => {
1759                let is_cloud = !matches!(
1760                    all_language_settings(None, cx).edit_predictions.provider,
1761                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1762                );
1763
1764                if is_cloud {
1765                    let organization_id = self
1766                        .user_store
1767                        .read(cx)
1768                        .current_organization()
1769                        .map(|organization| organization.id.clone());
1770
1771                    self.reject_predictions_tx
1772                        .unbounded_send(EditPredictionRejectionPayload {
1773                            rejection: EditPredictionRejection {
1774                                request_id: prediction_id.to_string(),
1775                                reason,
1776                                was_shown,
1777                                model_version,
1778                                e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
1779                            },
1780                            organization_id,
1781                        })
1782                        .log_err();
1783                }
1784            }
1785            EditPredictionModel::Mercury => {
1786                mercury::edit_prediction_rejected(
1787                    prediction_id,
1788                    was_shown,
1789                    reason,
1790                    self.client.http_client(),
1791                    cx,
1792                );
1793            }
1794            EditPredictionModel::Fim { .. } => {}
1795        }
1796    }
1797
1798    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1799        self.projects
1800            .get(&project.entity_id())
1801            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1802    }
1803
1804    pub fn refresh_prediction_from_buffer(
1805        &mut self,
1806        project: Entity<Project>,
1807        buffer: Entity<Buffer>,
1808        position: language::Anchor,
1809        cx: &mut Context<Self>,
1810    ) {
1811        self.queue_prediction_refresh(
1812            project.clone(),
1813            PredictEditsRequestTrigger::Other,
1814            buffer.entity_id(),
1815            cx,
1816            move |this, cx| {
1817                let Some(request_task) = this
1818                    .update(cx, |this, cx| {
1819                        this.request_prediction(
1820                            &project,
1821                            &buffer,
1822                            position,
1823                            PredictEditsRequestTrigger::Other,
1824                            cx,
1825                        )
1826                    })
1827                    .log_err()
1828                else {
1829                    return Task::ready(anyhow::Ok(None));
1830                };
1831
1832                cx.spawn(async move |_cx| {
1833                    request_task.await.map(|prediction_result| {
1834                        prediction_result.map(|prediction_result| {
1835                            (
1836                                prediction_result,
1837                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1838                            )
1839                        })
1840                    })
1841                })
1842            },
1843        )
1844    }
1845
    /// Queues a prediction request at a nearby diagnostic instead of at the cursor.
    ///
    /// Returns without queueing anything in any of these cases:
    /// - the configured provider is not served by this store
    ///   (`is_ep_store_provider`);
    /// - a workspace for `project` is following a leader in its active pane;
    /// - `project` has no state in this store;
    /// - `project` already has a current prediction, since buffer predictions
    ///   are preferred.
    ///
    /// The queued work runs with the `Diagnostics` trigger, throttled per
    /// project, and does the following:
    /// 1. Resolves the project's active buffer and cursor, and bails out if
    ///    predictions are disabled there.
    /// 2. Builds the range handed to `next_diagnostic_location`. For `Local`
    ///    this is `DIAGNOSTIC_LINES_RANGE` rows around the cursor. For `Global`
    ///    it is the default (empty) range.
    /// 3. Asks `next_diagnostic_location` for a jump target and requests a
    ///    prediction there.
    /// 4. If a current prediction appeared while the request was in flight,
    ///    turns the result into a `CurrentPreferred` rejection.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        if currently_following(&project, cx) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Snapshot the active buffer. With no known cursor position, the
                // cursor falls back to the default `Point`.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    // Diagnostics overlapping this range are skipped by
                    // `next_diagnostic_location`.
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // A buffer prediction may have arrived while this request was in
                    // flight; if so, keep it and report this one as `CurrentPreferred`.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                    e2e_latency: prediction_result.e2e_latency,
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
1962
1963    fn predictions_enabled_at(
1964        snapshot: &BufferSnapshot,
1965        position: Option<language::Anchor>,
1966        cx: &App,
1967    ) -> bool {
1968        let file = snapshot.file();
1969        let all_settings = all_language_settings(file, cx);
1970        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1971            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1972        {
1973            return false;
1974        }
1975
1976        if let Some(last_position) = position {
1977            let settings = snapshot.settings_at(last_position, cx);
1978
1979            if !settings.edit_predictions_disabled_in.is_empty()
1980                && let Some(scope) = snapshot.language_scope_at(last_position)
1981                && let Some(scope_name) = scope.override_name()
1982                && settings
1983                    .edit_predictions_disabled_in
1984                    .iter()
1985                    .any(|s| s == scope_name)
1986            {
1987                return false;
1988            }
1989        }
1990
1991        true
1992    }
1993
    /// Minimum spacing between two refreshes of the same kind for the same
    /// throttle entity. `queue_prediction_refresh` delays a queued refresh until
    /// this long after the previous one was let through.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1995}
1996
1997fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
1998    let Some(app_state) = AppState::try_global(cx) else {
1999        return false;
2000    };
2001
2002    app_state
2003        .workspace_store
2004        .read(cx)
2005        .workspaces()
2006        .filter_map(|workspace| workspace.upgrade())
2007        .any(|workspace| {
2008            workspace.read(cx).project().entity_id() == project.entity_id()
2009                && workspace
2010                    .read(cx)
2011                    .leader_for_pane(workspace.read(cx).active_pane())
2012                    .is_some()
2013        })
2014}
2015
2016fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2017    match provider {
2018        EditPredictionProvider::Zed
2019        | EditPredictionProvider::Mercury
2020        | EditPredictionProvider::Ollama
2021        | EditPredictionProvider::OpenAiCompatibleApi
2022        | EditPredictionProvider::Experimental(_) => true,
2023        EditPredictionProvider::None
2024        | EditPredictionProvider::Copilot
2025        | EditPredictionProvider::Codestral => false,
2026    }
2027}
2028
2029impl EditPredictionStore {
    /// Schedules a prediction refresh for `project` and records it as pending.
    ///
    /// `do_refresh` performs the actual request. It runs inside a spawned task
    /// with these steps:
    /// 1. Wait out the throttle. `select_throttle` picks the diagnostics
    ///    throttle or the general one. If the last refresh of that kind was for
    ///    `throttle_entity`, the task sleeps until `THROTTLE_TIMEOUT` after it.
    /// 2. Give up if this refresh was cancelled meanwhile, or if another refresh
    ///    of the same kind was let through after this one was enqueued.
    ///    Otherwise record this refresh as the latest one.
    /// 3. Run `do_refresh`. Unless cancelled, apply the result:
    ///    - If the new prediction wins `should_replace_prediction` against the
    ///      current one, the current one is rejected as `Replaced` and the new
    ///      one takes its place.
    ///    - If it loses, the new prediction is rejected as `CurrentPreferred`.
    ///    - An error result is rejected with its own reason.
    /// 4. Remove this refresh from the pending list, and cancel every pending
    ///    prediction that was enqueued before it.
    ///
    /// At most `max_pending_predictions` (chosen per provider) are kept pending.
    /// At the limit, the most recently enqueued pending entry is popped and
    /// cancelled to make room. For providers without acceptance tracking
    /// (`Ollama`, `OpenAiCompatibleApi`) the entry's `drop_on_cancel` flag is
    /// set; its exact effect lives in `cancel_pending_prediction`. Providers not
    /// served by this store log an error, and nothing is queued.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        request_trigger: PredictEditsRequestTrigger,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        /// Diagnostics-triggered refreshes are throttled independently of all others.
        fn select_throttle(
            project_state: &mut ProjectState,
            request_trigger: PredictEditsRequestTrigger,
        ) -> &mut Option<(EntityId, Instant)> {
            match request_trigger {
                PredictEditsRequestTrigger::Diagnostics => {
                    &mut project_state.last_jump_prediction_refresh
                }
                _ => &mut project_state.last_edit_prediction_refresh,
            }
        }

        // Per-provider policy: (needs acceptance tracking, max pending requests).
        let (needs_acceptance_tracking, max_pending_predictions) =
            match all_language_settings(None, cx).edit_predictions.provider {
                EditPredictionProvider::Zed
                | EditPredictionProvider::Mercury
                | EditPredictionProvider::Experimental(_) => (true, 2),
                EditPredictionProvider::Ollama => (false, 1),
                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
                EditPredictionProvider::None
                | EditPredictionProvider::Copilot
                | EditPredictionProvider::Codestral => {
                    log::error!("queue_prediction_refresh called with non-store provider");
                    return;
                }
            };

        let drop_on_cancel = !needs_acceptance_tracking;
        let throttle_timeout = Self::THROTTLE_TIMEOUT;
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        // Remember the throttle state at enqueue time; if it has changed by the
        // time this task runs, another refresh has already gone ahead of it.
        let throttle_at_enqueue = *select_throttle(project_state, request_trigger);

        let task = cx.spawn(async move |this, cx| {
            // Remaining throttle delay, but only when the previous refresh of this
            // kind targeted the same throttle entity.
            let throttle_wait = this
                .update(cx, |this, cx| {
                    let project_state = this.get_or_init_project(&project, cx);
                    let throttle = *select_throttle(project_state, request_trigger);

                    let now = cx.background_executor().now();
                    throttle.and_then(|(last_entity, last_timestamp)| {
                        if throttle_entity != last_entity {
                            return None;
                        }
                        (last_timestamp + throttle_timeout).checked_duration_since(now)
                    })
                })
                .ok()
                .flatten();

            if let Some(timeout) = throttle_wait {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request. Also skip if another task already
            // proceeded since we were enqueued (duplicate).
            // `is_cancelled` also stays `true` if the store entity is gone.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                let was_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);
                if was_cancelled {
                    return;
                }

                // Another request has been already sent since this was enqueued
                if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
                    return;
                }

                // Claim the throttle slot before making the request.
                let new_refresh = (throttle_entity, cx.background_executor().now());
                *select_throttle(project_state, request_trigger) = Some(new_refresh);
                is_cancelled = false;
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                // A cancellation that arrived while the request was in flight
                // discards its result.
                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                                shown_with: None,
                                e2e_latency: prediction_result.e2e_latency,
                            };

                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                        cx,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                        new_prediction.prediction.model_version,
                                        Some(new_prediction.e2e_latency),
                                        cx,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(
                                prediction_result.id,
                                reject_reason,
                                false,
                                None,
                                Some(prediction_result.e2e_latency),
                                cx,
                            );
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-fetch: the rejection calls above needed `this` mutably.
                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Drop this entry from the pending list and cancel everything that
                // was enqueued before it. The loop breaks right after mutating
                // `pending_predictions`, so the iterator is never resumed.
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        if project_state.pending_predictions.len() < max_pending_predictions {
            project_state
                .pending_predictions
                .push(PendingPrediction {
                    id: pending_prediction_id,
                    task,
                    drop_on_cancel,
                })
                .unwrap();
        } else {
            // At capacity: replace the most recently enqueued pending entry (the
            // last one) with this one and cancel it. Earlier entries keep running.
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state
                .pending_predictions
                .push(PendingPrediction {
                    id: pending_prediction_id,
                    task,
                    drop_on_cancel,
                })
                .unwrap();
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
2241
2242    pub fn request_prediction(
2243        &mut self,
2244        project: &Entity<Project>,
2245        active_buffer: &Entity<Buffer>,
2246        position: language::Anchor,
2247        trigger: PredictEditsRequestTrigger,
2248        cx: &mut Context<Self>,
2249    ) -> Task<Result<Option<EditPredictionResult>>> {
2250        self.request_prediction_internal(
2251            project.clone(),
2252            active_buffer.clone(),
2253            position,
2254            trigger,
2255            cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2256            cx,
2257        )
2258    }
2259
2260    fn request_prediction_internal(
2261        &mut self,
2262        project: Entity<Project>,
2263        active_buffer: Entity<Buffer>,
2264        position: language::Anchor,
2265        trigger: PredictEditsRequestTrigger,
2266        allow_jump: bool,
2267        cx: &mut Context<Self>,
2268    ) -> Task<Result<Option<EditPredictionResult>>> {
2269        self.get_or_init_project(&project, cx);
2270        let project_state = self.projects.get(&project.entity_id()).unwrap();
2271        let stored_events = project_state.events(cx);
2272        let has_events = !stored_events.is_empty();
2273        let events: Vec<Arc<zeta_prompt::Event>> =
2274            stored_events.iter().map(|e| e.event.clone()).collect();
2275        let debug_tx = project_state.debug_tx.clone();
2276
2277        let snapshot = active_buffer.read(cx).snapshot();
2278        let cursor_point = position.to_point(&snapshot);
2279        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2280        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2281        let diagnostic_search_range =
2282            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2283
2284        let related_files = self.context_for_project(&project, cx);
2285
2286        let is_open_source = snapshot
2287            .file()
2288            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2289            && events.iter().all(|event| event.in_open_source_repo())
2290            && related_files.iter().all(|file| file.in_open_source_repo);
2291
2292        let can_collect_data = !cfg!(test)
2293            && is_open_source
2294            && self.is_data_collection_enabled(cx)
2295            && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2296
2297        let inputs = EditPredictionModelInput {
2298            project: project.clone(),
2299            buffer: active_buffer,
2300            snapshot,
2301            position,
2302            events,
2303            related_files,
2304            trigger,
2305            diagnostic_search_range: diagnostic_search_range,
2306            debug_tx,
2307            can_collect_data,
2308            is_open_source,
2309        };
2310
2311        let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2312
2313        let task = match self.edit_prediction_model {
2314            EditPredictionModel::Zeta => {
2315                zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2316            }
2317            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2318            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2319        };
2320
2321        cx.spawn(async move |this, cx| {
2322            let prediction = task.await?;
2323
2324            // Only fall back to diagnostics-based prediction if we got a
2325            // the model had nothing to suggest for the buffer
2326            if prediction.is_none()
2327                && allow_jump
2328                && has_events
2329                && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2330            {
2331                this.update(cx, |this, cx| {
2332                    this.refresh_prediction_from_diagnostics(
2333                        project,
2334                        DiagnosticSearchScope::Local,
2335                        cx,
2336                    );
2337                })?;
2338                return anyhow::Ok(None);
2339            }
2340
2341            Ok(prediction)
2342        })
2343    }
2344
2345    pub(crate) async fn next_diagnostic_location(
2346        active_buffer: Entity<Buffer>,
2347        active_buffer_snapshot: &BufferSnapshot,
2348        active_buffer_diagnostic_search_range: Range<Point>,
2349        active_buffer_cursor_point: Point,
2350        project: &Entity<Project>,
2351        cx: &mut AsyncApp,
2352    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2353        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2354            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2355            .flat_map(|(_, _, _, selections)| {
2356                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2357            })
2358            .collect();
2359
2360        let mut jump_location = active_buffer_snapshot
2361            .diagnostic_groups(None)
2362            .into_iter()
2363            .filter_map(|(_, group)| {
2364                let range = &group.entries[group.primary_ix]
2365                    .range
2366                    .to_point(&active_buffer_snapshot);
2367                if range.overlaps(&active_buffer_diagnostic_search_range) {
2368                    return None;
2369                }
2370                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2371                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2372                });
2373                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2374                    <= DIAGNOSTIC_LINES_RANGE;
2375                if near_collaborator && !near_local {
2376                    return None;
2377                }
2378                Some(range.start)
2379            })
2380            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2381            .map(|position| {
2382                (
2383                    active_buffer.clone(),
2384                    active_buffer_snapshot.anchor_before(position),
2385                )
2386            });
2387
2388        if jump_location.is_none() {
2389            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2390                let file = buffer.file()?;
2391
2392                Some(ProjectPath {
2393                    worktree_id: file.worktree_id(cx),
2394                    path: file.path().clone(),
2395                })
2396            });
2397
2398            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2399                project
2400                    .diagnostic_summaries(false, cx)
2401                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2402                    .map(|(path, _, _)| {
2403                        let shared_prefix = path
2404                            .path
2405                            .components()
2406                            .zip(
2407                                active_buffer_path
2408                                    .as_ref()
2409                                    .map(|p| p.path.components())
2410                                    .unwrap_or_default(),
2411                            )
2412                            .take_while(|(a, b)| a == b)
2413                            .count();
2414                        (path, shared_prefix)
2415                    })
2416                    .collect()
2417            });
2418
2419            candidates.sort_by(|a, b| b.1.cmp(&a.1));
2420
2421            for (path, _) in candidates {
2422                let candidate_buffer = project
2423                    .update(cx, |project, cx| project.open_buffer(path, cx))
2424                    .await?;
2425
2426                let (has_collaborators, diagnostic_position) =
2427                    candidate_buffer.read_with(cx, |buffer, _cx| {
2428                        let snapshot = buffer.snapshot();
2429                        let has_collaborators = snapshot
2430                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2431                            .next()
2432                            .is_some();
2433                        let position = buffer
2434                            .buffer_diagnostics(None)
2435                            .into_iter()
2436                            .min_by_key(|entry| entry.diagnostic.severity)
2437                            .map(|entry| entry.range.start);
2438                        (has_collaborators, position)
2439                    });
2440
2441                if has_collaborators {
2442                    continue;
2443                }
2444
2445                if let Some(position) = diagnostic_position {
2446                    jump_location = Some((candidate_buffer, position));
2447                    break;
2448                }
2449            }
2450        }
2451
2452        anyhow::Ok(jump_location)
2453    }
2454
2455    async fn send_raw_llm_request(
2456        request: RawCompletionRequest,
2457        client: Arc<Client>,
2458        custom_url: Option<Arc<Url>>,
2459        llm_token: LlmApiToken,
2460        organization_id: Option<OrganizationId>,
2461        app_version: Version,
2462    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2463        let url = if let Some(custom_url) = custom_url {
2464            custom_url.as_ref().clone()
2465        } else {
2466            client
2467                .http_client()
2468                .build_zed_llm_url("/predict_edits/raw", &[])?
2469        };
2470
2471        Self::send_api_request(
2472            |builder| {
2473                let req = builder
2474                    .uri(url.as_ref())
2475                    .body(serde_json::to_string(&request)?.into());
2476                Ok(req?)
2477            },
2478            client,
2479            llm_token,
2480            organization_id,
2481            app_version,
2482            true,
2483        )
2484        .await
2485    }
2486
2487    pub(crate) async fn send_v3_request(
2488        input: ZetaPromptInput,
2489        client: Arc<Client>,
2490        llm_token: LlmApiToken,
2491        organization_id: Option<OrganizationId>,
2492        app_version: Version,
2493        trigger: PredictEditsRequestTrigger,
2494    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2495        let url = client
2496            .http_client()
2497            .build_zed_llm_url("/predict_edits/v3", &[])?;
2498
2499        let request = PredictEditsV3Request { input, trigger };
2500
2501        let json_bytes = serde_json::to_vec(&request)?;
2502        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2503
2504        Self::send_api_request(
2505            |builder| {
2506                let req = builder
2507                    .uri(url.as_ref())
2508                    .header("Content-Encoding", "zstd")
2509                    .body(compressed.clone().into());
2510                Ok(req?)
2511            },
2512            client,
2513            llm_token,
2514            organization_id,
2515            app_version,
2516            true,
2517        )
2518        .await
2519    }
2520
    /// Sends a POST request built by `build` and deserializes a successful JSON response body as
    /// `Res`. Also returns any `EditPredictionUsage` parsed from the response headers.
    ///
    /// `build` receives a builder with the method, the `Content-Type` and Zed version headers,
    /// and (when a token is available) the `Authorization` header already set. It is called once
    /// per attempt, so it may run more than once.
    ///
    /// When `require_auth` is true, failing to acquire an LLM token is an error. Otherwise the
    /// request is sent without an `Authorization` header. If a token was sent and the response
    /// says it needs refreshing, the token is refreshed and the request is retried once.
    ///
    /// # Errors
    ///
    /// Returns `ZedUpdateRequiredError` when the response advertises a minimum required version
    /// newer than `app_version`. Any other non-success response produces an error containing
    /// its status and body. Transport, token and deserialization errors are also propagated.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        organization_id: Option<OrganizationId>,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();

        // With `require_auth`, a token-acquisition failure aborts the request; otherwise the
        // request proceeds unauthenticated.
        let mut token = if require_auth {
            Some(llm_token.acquire(&client, organization_id.clone()).await?)
        } else {
            llm_token
                .acquire(&client, organization_id.clone())
                .await
                .ok()
        };
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // Checked before the status, so an outdated client gets this error instead of the
            // generic failure below.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                // Missing or malformed usage headers are not an error.
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
                // Retry at most once, and only when a token was actually sent.
                did_retry = true;
                token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2594
2595    pub fn refresh_context(
2596        &mut self,
2597        project: &Entity<Project>,
2598        buffer: &Entity<language::Buffer>,
2599        cursor_position: language::Anchor,
2600        cx: &mut Context<Self>,
2601    ) {
2602        self.get_or_init_project(project, cx)
2603            .context
2604            .update(cx, |store, cx| {
2605                store.refresh(buffer.clone(), cursor_position, cx);
2606            });
2607    }
2608
2609    #[cfg(feature = "cli-support")]
2610    pub fn set_context_for_buffer(
2611        &mut self,
2612        project: &Entity<Project>,
2613        related_files: Vec<RelatedFile>,
2614        cx: &mut Context<Self>,
2615    ) {
2616        self.get_or_init_project(project, cx)
2617            .context
2618            .update(cx, |store, cx| {
2619                store.set_related_files(related_files, cx);
2620            });
2621    }
2622
2623    #[cfg(feature = "cli-support")]
2624    pub fn set_recent_paths_for_project(
2625        &mut self,
2626        project: &Entity<Project>,
2627        paths: impl IntoIterator<Item = project::ProjectPath>,
2628        cx: &mut Context<Self>,
2629    ) {
2630        let project_state = self.get_or_init_project(project, cx);
2631        project_state.recent_paths = paths.into_iter().collect();
2632    }
2633
2634    fn is_file_open_source(
2635        &self,
2636        project: &Entity<Project>,
2637        file: &Arc<dyn File>,
2638        cx: &App,
2639    ) -> bool {
2640        if !file.is_local() || file.is_private() {
2641            return false;
2642        }
2643        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2644            return false;
2645        };
2646        project_state
2647            .license_detection_watchers
2648            .get(&file.worktree_id(cx))
2649            .as_ref()
2650            .is_some_and(|watcher| watcher.is_project_open_source())
2651    }
2652
2653    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2654        self.data_collection_choice.is_enabled(cx)
2655    }
2656
2657    fn load_data_collection_choice(cx: &App) -> DataCollectionChoice {
2658        let choice = KeyValueStore::global(cx)
2659            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2660            .log_err()
2661            .flatten();
2662
2663        match choice.as_deref() {
2664            Some("true") => DataCollectionChoice::Enabled,
2665            Some("false") => DataCollectionChoice::Disabled,
2666            Some(_) => {
2667                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2668                DataCollectionChoice::NotAnswered
2669            }
2670            None => DataCollectionChoice::NotAnswered,
2671        }
2672    }
2673
2674    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2675        self.data_collection_choice = self.data_collection_choice.toggle();
2676        let new_choice = self.data_collection_choice;
2677        let is_enabled = new_choice.is_enabled(cx);
2678        let kvp = KeyValueStore::global(cx);
2679        db::write_and_log(cx, move || async move {
2680            kvp.write_kvp(
2681                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2682                is_enabled.to_string(),
2683            )
2684            .await
2685        });
2686    }
2687
2688    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2689        self.shown_predictions.iter()
2690    }
2691
2692    pub fn shown_completions_len(&self) -> usize {
2693        self.shown_predictions.len()
2694    }
2695
2696    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2697        self.rated_predictions.contains(id)
2698    }
2699
2700    pub fn rate_prediction(
2701        &mut self,
2702        prediction: &EditPrediction,
2703        rating: EditPredictionRating,
2704        feedback: String,
2705        cx: &mut Context<Self>,
2706    ) {
2707        let organization = self.user_store.read(cx).current_organization();
2708
2709        self.rated_predictions.insert(prediction.id.clone());
2710
2711        cx.background_spawn({
2712            let client = self.client.clone();
2713            let prediction_id = prediction.id.to_string();
2714            let inputs = serde_json::to_value(&prediction.inputs);
2715            let output = prediction
2716                .edit_preview
2717                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2718            async move {
2719                client
2720                    .cloud_client()
2721                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2722                        organization_id: organization.map(|organization| organization.id.clone()),
2723                        request_id: prediction_id,
2724                        rating: match rating {
2725                            EditPredictionRating::Positive => "positive".to_string(),
2726                            EditPredictionRating::Negative => "negative".to_string(),
2727                        },
2728                        inputs: inputs?,
2729                        output,
2730                        feedback,
2731                    })
2732                    .await?;
2733
2734                anyhow::Ok(())
2735            }
2736        })
2737        .detach_and_log_err(cx);
2738
2739        cx.notify();
2740    }
2741}
2742
2743fn collaborator_edit_overlaps_locality_region(
2744    project_state: &ProjectState,
2745    project: &Entity<Project>,
2746    buffer: &Entity<Buffer>,
2747    snapshot: &BufferSnapshot,
2748    edit_range: &Range<Anchor>,
2749    cx: &App,
2750) -> bool {
2751    let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2752        return false;
2753    };
2754
2755    if active_buffer.entity_id() != buffer.entity_id() {
2756        return false;
2757    }
2758
2759    let locality_point_range = expand_context_syntactically_then_linewise(
2760        snapshot,
2761        (position..position).to_point(snapshot),
2762        COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2763    );
2764    let locality_anchor_range = snapshot.anchor_range_around(locality_point_range);
2765
2766    edit_range.overlaps(&locality_anchor_range, snapshot)
2767}
2768
/// Collapses the trailing run of mergeable events in `events` into a single event.
///
/// Walking backwards from the newest event, events are counted as long as each older event
/// `can_merge` with the next newer one, given `latest_snapshot` and `latest_edit_range`. If at
/// least two events qualify, they are replaced by a single `BufferChange`. Its diff is computed
/// by `compute_diff_between_snapshots_in_range`, from the oldest event's `old_snapshot` to
/// `end_snapshot`, over the union of the merged events' `total_edit_range`s. The merged event is
/// marked `predicted` only if every merged event was.
///
/// Nothing changes in these cases:
/// - the newest event's `old_snapshot` has a different `remote_id` than `latest_snapshot`;
/// - `latest_snapshot` has not observed the newest event's `new_snapshot_version`;
/// - the diff computation returns `None`.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Only merge when the newest event is for the same buffer and is not ahead of
    // `latest_snapshot`.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Count, newest to oldest, how many consecutive events can be merged with their successor.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A single event has nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    // `mergeable_count >= 2`, so the range is non-empty and `peek` cannot return `None`.
    // `peek` does not advance the iterator, so the `all` call below still visits this event.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    // Union of all merged events' edit ranges, starting from the oldest one.
    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // Predicted only if every merged event was predicted.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Replace the merged run with the single merged event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2851
2852fn merge_anchor_ranges(
2853    left: &Range<Anchor>,
2854    right: &Range<Anchor>,
2855    snapshot: &TextBufferSnapshot,
2856) -> Range<Anchor> {
2857    let start = if left.start.cmp(&right.start, snapshot).is_le() {
2858        left.start
2859    } else {
2860        right.start
2861    };
2862    let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2863        left.end
2864    } else {
2865        right.end
2866    };
2867    start..end
2868}
2869
/// Error returned when an edit prediction response advertises, via the
/// `MINIMUM_REQUIRED_VERSION_HEADER_NAME` header, a minimum Zed version newer than the running
/// one.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum version reported by the server.
    minimum_version: Version,
}
2877
2878#[derive(Debug, Clone, Copy)]
2879pub enum DataCollectionChoice {
2880    NotAnswered,
2881    Enabled,
2882    Disabled,
2883}
2884
2885impl DataCollectionChoice {
2886    pub fn is_enabled(self, cx: &App) -> bool {
2887        if cx.is_staff() {
2888            return true;
2889        }
2890        match self {
2891            Self::Enabled => true,
2892            Self::NotAnswered | Self::Disabled => false,
2893        }
2894    }
2895
2896    #[must_use]
2897    pub fn toggle(&self) -> DataCollectionChoice {
2898        match self {
2899            Self::Enabled => Self::Disabled,
2900            Self::Disabled => Self::Enabled,
2901            Self::NotAnswered => Self::Enabled,
2902        }
2903    }
2904}
2905
2906impl From<bool> for DataCollectionChoice {
2907    fn from(value: bool) -> Self {
2908        match value {
2909            true => DataCollectionChoice::Enabled,
2910            false => DataCollectionChoice::Disabled,
2911        }
2912    }
2913}
2914
2915struct ZedPredictUpsell;
2916
2917impl Dismissable for ZedPredictUpsell {
2918    const KEY: &'static str = "dismissed-edit-predict-upsell";
2919
2920    fn dismissed(cx: &App) -> bool {
2921        // To make this backwards compatible with older versions of Zed, we
2922        // check if the user has seen the previous Edit Prediction Onboarding
2923        // before, by checking the data collection choice which was written to
2924        // the database once the user clicked on "Accept and Enable"
2925        let kvp = KeyValueStore::global(cx);
2926        if kvp
2927            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2928            .log_err()
2929            .is_some_and(|s| s.is_some())
2930        {
2931            return true;
2932        }
2933
2934        kvp.read_kvp(Self::KEY)
2935            .log_err()
2936            .is_some_and(|s| s.is_some())
2937    }
2938}
2939
2940pub fn should_show_upsell_modal(cx: &App) -> bool {
2941    !ZedPredictUpsell::dismissed(cx)
2942}
2943
2944pub fn init(cx: &mut App) {
2945    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2946        workspace.register_action(
2947            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2948                ZedPredictModal::toggle(
2949                    workspace,
2950                    workspace.user_store().clone(),
2951                    workspace.client().clone(),
2952                    window,
2953                    cx,
2954                )
2955            },
2956        );
2957
2958        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2959            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2960                settings
2961                    .project
2962                    .all_languages
2963                    .edit_predictions
2964                    .get_or_insert_default()
2965                    .provider = Some(EditPredictionProvider::None)
2966            });
2967        });
2968        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2969            EditPredictionStore::try_global(cx).and_then(|store| {
2970                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2971            })
2972        }
2973
2974        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2975            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2976                copilot_ui::initiate_sign_in(copilot, window, cx);
2977            }
2978        });
2979        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2980            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2981                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2982            }
2983        });
2984        workspace.register_action(|workspace, _: &SignOut, window, cx| {
2985            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2986                copilot_ui::initiate_sign_out(copilot, window, cx);
2987            }
2988        });
2989    })
2990    .detach();
2991}