edit_prediction.rs

   1use anyhow::Result;
   2use client::{
   3    Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore,
   4    global_llm_token as global_llm_api_token,
   5};
   6use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
   7use cloud_llm_client::predict_edits_v3::{
   8    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   9};
  10use cloud_llm_client::{
  11    EditPredictionRejectReason, EditPredictionRejection,
  12    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  13    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  14};
  15use collections::{HashMap, HashSet};
  16use copilot::{Copilot, Reinstall, SignIn, SignOut};
  17use credentials_provider::CredentialsProvider;
  18use db::kvp::{Dismissable, KeyValueStore};
  19use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  20use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  21use futures::{
  22    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  23    channel::mpsc::{self, UnboundedReceiver},
  24    select_biased,
  25};
  26use gpui::BackgroundExecutor;
  27use gpui::http_client::Url;
  28use gpui::{
  29    App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
  30    http_client::{self, AsyncBody, Method},
  31    prelude::*,
  32};
  33use heapless::Vec as ArrayVec;
  34use language::language_settings::all_language_settings;
  35use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  36use language::{BufferSnapshot, OffsetRangeExt};
  37use language_model::LlmApiToken;
  38use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  39use release_channel::AppVersion;
  40use semver::Version;
  41use serde::de::DeserializeOwned;
  42use settings::{
  43    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  44};
  45use std::collections::{VecDeque, hash_map};
  46use std::env;
  47use text::{AnchorRangeExt, Edit};
  48use workspace::{AppState, Workspace};
  49use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  50
  51use std::mem;
  52use std::ops::Range;
  53use std::path::Path;
  54use std::rc::Rc;
  55use std::str::FromStr as _;
  56use std::sync::Arc;
  57use std::time::{Duration, Instant};
  58
  59use thiserror::Error;
  60use util::{RangeExt as _, ResultExt as _};
  61
  62pub mod cursor_excerpt;
  63pub mod example_spec;
  64pub mod fim;
  65mod license_detection;
  66pub mod mercury;
  67pub mod ollama;
  68mod onboarding_modal;
  69pub mod open_ai_response;
  70mod prediction;
  71
  72pub mod udiff;
  73
  74mod capture_example;
  75pub mod open_ai_compatible;
  76mod zed_edit_prediction_delegate;
  77pub mod zeta;
  78
  79#[cfg(test)]
  80mod edit_prediction_tests;
  81
  82use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
  83use crate::example_spec::ExampleSpec;
  84use crate::license_detection::LicenseDetectionWatcher;
  85use crate::mercury::Mercury;
  86use crate::onboarding_modal::ZedPredictModal;
  87pub use crate::prediction::EditPrediction;
  88pub use crate::prediction::EditPredictionId;
  89use crate::prediction::EditPredictionResult;
  90pub use capture_example::capture_example;
  91pub use language_model::ApiKeyState;
  92pub use telemetry_events::EditPredictionRating;
  93pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  94
// Actions registered under the `edit_prediction` namespace. The `///` lines inside the macro
// are part of the macro's input (presumably surfaced as each action's documentation), so they
// are intentionally left as-is.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
 104
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Line distance within which two edit ranges count as "near" each other. `StoredEvent::can_merge`
/// uses it both for proximity to the latest edit and for proximity between two events.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Upper bound on the length (in buffer offsets) of the old and new context regions of one
/// edit-history diff. `compute_diff_between_snapshots_in_range` returns `None` (so no event is
/// recorded) when either region exceeds it.
const EDIT_HISTORY_DIFF_SIZE_LIMIT: usize = 2048 * 3; // ~2048 tokens or ~50% of typical prompt budget
// NOTE(review): the constants below are not referenced in this part of the file; the notes are
// inferred from their names and should be confirmed at the use sites.
/// Token budget for context around collaborator edits (presumably).
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
/// Time window for grouping consecutive changes into one event (presumably).
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Persistence key for the user's data-collection choice (presumably a key-value store key).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval for reporting rejected predictions (presumably).
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
/// Telemetry event name emitted when a prediction settles.
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
/// Maximum time a prediction is tracked for settling (presumably).
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
/// Quiet period without edits after which a prediction counts as settled (presumably).
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
 116
/// Feature flag (`edit_prediction_jumps`) gating jump-style edit predictions.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
 122
/// gpui global wrapper holding the single shared [`EditPredictionStore`] entity.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 127
/// Configuration for using the raw Zeta2 endpoint.
///
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// It is initially read from the `ZED_ZETA_FORMAT`, `ZED_ZETA_MODEL` and
/// `ZED_ZETA_ENVIRONMENT` environment variables (see
/// `EditPredictionStore::zeta2_raw_config_from_env`) and can be overridden with
/// `EditPredictionStore::set_zeta2_raw_config`.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier override (`ZED_ZETA_MODEL`).
    pub model_id: Option<String>,
    /// Optional serving-environment override (`ZED_ZETA_ENVIRONMENT`).
    // NOTE(review): an earlier comment described this as the Baseten environment name
    // (lowercased); confirm at the request-building site.
    pub environment: Option<String>,
    /// Prompt format used when building the raw prompt (`ZED_ZETA_FORMAT`).
    pub format: ZetaFormat,
}
 137
/// Central store driving edit predictions.
///
/// A single instance is shared as a gpui global (see `EditPredictionStore::global`).
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token used to authenticate requests to the LLM service.
    llm_token: LlmApiToken,
    /// Experiment-fetching task started in `new`; stored so it lives as long as the store.
    _fetch_experiments_task: Task<()>,
    /// Per-project state, keyed by entity id (presumably the project's).
    projects: HashMap<EntityId, ProjectState>,
    // NOTE(review): presumably set when the server requires a newer client version — confirm.
    update_required: bool,
    /// Active prediction backend; `new` initializes it to `EditPredictionModel::Zeta`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration, initially read from the environment.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    /// User-selected experiment; takes priority in `active_experiment`.
    preferred_experiment: Option<String>,
    /// Experiments available to choose from (refreshed for staff users).
    available_experiments: Vec<String>,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Sender for the background task that reports rejected predictions.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Sender for the worker that processes settled predictions.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    /// Predictions that have been shown; `active_experiment` inspects their `model_version`.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
    credentials_provider: Arc<dyn CredentialsProvider>,
}
 159
/// A rejection queued on `reject_predictions_tx`, together with the organization it applies to.
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}
 164
/// The backend used to produce edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    /// Zed's Zeta model (the default chosen in `EditPredictionStore::new`).
    Zeta,
    /// Fill-in-the-middle style prompting with the given prompt format (presumably).
    Fim { format: EditPredictionPromptFormat },
    /// The Mercury provider.
    Mercury,
}
 171
/// Everything a prediction backend needs to produce a prediction for one request.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Position the prediction is requested at (presumably the cursor).
    position: Anchor,
    /// Recent edit history to include in the prompt.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    trigger: PredictEditsRequestTrigger,
    /// Row range to search for diagnostics (presumably).
    diagnostic_search_range: Range<Point>,
    /// Optional channel for emitting [`DebugEvent`]s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
}
 186
/// Diagnostic events describing the prediction pipeline, sent over a `debug_tx` channel.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Emitted when context retrieval for a project begins.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Emitted when context retrieval for a project completes.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Free-form key/value details about the retrieval.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when a prediction request is started.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// The prompt sent, when available.
    pub prompt: Option<String>,
}

/// Emitted when a prediction request finishes.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// The raw model output, when available.
    pub model_output: Option<String>,
}
 222
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-facing event (e.g. a buffer-change diff).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Buffer version right after the change.
    pub new_snapshot_version: clock::Global,
    /// Range covered by the change, anchored in the post-change buffer.
    pub total_edit_range: Range<Anchor>,
}
 231
 232impl StoredEvent {
 233    fn can_merge(
 234        &self,
 235        next_old_event: &StoredEvent,
 236        latest_snapshot: &TextBufferSnapshot,
 237        latest_edit_range: &Range<Anchor>,
 238    ) -> bool {
 239        // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
 240        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 241            return false;
 242        }
 243        if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
 244            return false;
 245        }
 246        if self.new_snapshot_version != next_old_event.old_snapshot.version {
 247            return false;
 248        }
 249        if !latest_snapshot
 250            .version
 251            .observed_all(&next_old_event.new_snapshot_version)
 252        {
 253            return false;
 254        }
 255
 256        let a_is_predicted = matches!(
 257            self.event.as_ref(),
 258            zeta_prompt::Event::BufferChange {
 259                predicted: true,
 260                ..
 261            }
 262        );
 263        let b_is_predicted = matches!(
 264            next_old_event.event.as_ref(),
 265            zeta_prompt::Event::BufferChange {
 266                predicted: true,
 267                ..
 268            }
 269        );
 270
 271        // If events come from the same source (both predicted or both manual) then
 272        // we would have coalesced them already.
 273        if a_is_predicted == b_is_predicted {
 274            return false;
 275        }
 276
 277        let left_range = self.total_edit_range.to_point(latest_snapshot);
 278        let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
 279        let latest_range = latest_edit_range.to_point(latest_snapshot);
 280
 281        // Events near to the latest edit are not merged if their sources differ.
 282        if lines_between_ranges(&left_range, &latest_range)
 283            .min(lines_between_ranges(&right_range, &latest_range))
 284            <= CHANGE_GROUPING_LINE_SPAN
 285        {
 286            return false;
 287        }
 288
 289        // Events that are distant from each other are not merged.
 290        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 291            return false;
 292        }
 293
 294        true
 295    }
 296}
 297
 298fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 299    if left.start > right.end {
 300        return left.start.row - right.end.row;
 301    }
 302    if right.start > left.end {
 303        return right.start.row - left.end.row;
 304    }
 305    0
 306}
 307
/// Per-project edit prediction state.
struct ProjectState {
    /// Finalized edit history for the project.
    events: VecDeque<StoredEvent>,
    /// The edit currently being accumulated; finalized on demand by `ProjectState::events`.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (fixed capacity of two).
    pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
    /// Optional channel for [`DebugEvent`]s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that have been cancelled.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// Per-worktree license detection; used to flag events as coming from open-source code.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
 325
 326impl ProjectState {
 327    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 328        self.events
 329            .iter()
 330            .cloned()
 331            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 332                let (one, two) = event.split_by_pause();
 333                let one = one.finalize(&self.license_detection_watchers, cx);
 334                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 335                one.into_iter().chain(two)
 336            }))
 337            .collect()
 338    }
 339
 340    fn cancel_pending_prediction(
 341        &mut self,
 342        pending_prediction: PendingPrediction,
 343        cx: &mut Context<EditPredictionStore>,
 344    ) {
 345        self.cancelled_predictions.insert(pending_prediction.id);
 346
 347        if pending_prediction.drop_on_cancel {
 348            drop(pending_prediction.task);
 349        } else {
 350            cx.spawn(async move |this, cx| {
 351                let Some(prediction_id) = pending_prediction.task.await else {
 352                    return;
 353                };
 354
 355                this.update(cx, |this, cx| {
 356                    this.reject_prediction(
 357                        prediction_id,
 358                        EditPredictionRejectReason::Canceled,
 359                        false,
 360                        None,
 361                        None,
 362                        cx,
 363                    );
 364                })
 365                .ok();
 366            })
 367            .detach()
 368        }
 369    }
 370
 371    fn active_buffer(
 372        &self,
 373        project: &Entity<Project>,
 374        cx: &App,
 375    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 376        let project = project.read(cx);
 377        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 378        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 379        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 380        Some((active_buffer, registered_buffer.last_position))
 381    }
 382}
 383
/// The prediction currently offered for a project, plus bookkeeping about how it was shown.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been displayed to the user.
    pub was_shown: bool,
    /// How it was displayed, if known.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
    /// End-to-end latency of the request (presumably request start to response).
    pub e2e_latency: std::time::Duration,
}
 392
 393impl CurrentEditPrediction {
 394    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 395        let Some(new_edits) = self
 396            .prediction
 397            .interpolate(&self.prediction.buffer.read(cx))
 398        else {
 399            return false;
 400        };
 401
 402        if self.prediction.buffer != old_prediction.prediction.buffer {
 403            return true;
 404        }
 405
 406        let Some(old_edits) = old_prediction
 407            .prediction
 408            .interpolate(&old_prediction.prediction.buffer.read(cx))
 409        else {
 410            return true;
 411        };
 412
 413        let requested_by_buffer_id = self.requested_by.buffer_id();
 414
 415        // This reduces the occurrence of UI thrash from replacing edits
 416        //
 417        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 418        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 419            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 420            && old_edits.len() == 1
 421            && new_edits.len() == 1
 422        {
 423            let (old_range, old_text) = &old_edits[0];
 424            let (new_range, new_text) = &new_edits[0];
 425            new_range == old_range && new_text.starts_with(old_text.as_ref())
 426        } else {
 427            true
 428        }
 429    }
 430}
 431
 432#[derive(Debug, Clone)]
 433enum PredictionRequestedBy {
 434    DiagnosticsUpdate,
 435    Buffer(EntityId),
 436}
 437
 438impl PredictionRequestedBy {
 439    pub fn buffer_id(&self) -> Option<EntityId> {
 440        match self {
 441            PredictionRequestedBy::DiagnosticsUpdate => None,
 442            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 443        }
 444    }
 445}
 446
/// Number of lines around a position to consider when searching diagnostics (presumably).
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Where to look for diagnostics: near the current position, or across the project
/// (presumably; confirm at the use sites).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
 454
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `ProjectState::cancelled_predictions` on cancellation.
    id: usize,
    /// Resolves to the produced prediction's id, if any.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 463
 464/// A prediction from the perspective of a buffer.
 465#[derive(Debug)]
 466enum BufferEditPrediction<'a> {
 467    Local { prediction: &'a EditPrediction },
 468    Jump { prediction: &'a EditPrediction },
 469}
 470
 471#[cfg(test)]
 472impl std::ops::Deref for BufferEditPrediction<'_> {
 473    type Target = EditPrediction;
 474
 475    fn deref(&self) -> &Self::Target {
 476        match self {
 477            BufferEditPrediction::Local { prediction } => prediction,
 478            BufferEditPrediction::Jump { prediction } => prediction,
 479        }
 480    }
 481}
 482
/// A shown prediction being watched until its region of the buffer stops changing
/// (presumably, to report a "settled" telemetry event).
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Region of the buffer the prediction could edit.
    editable_anchor_range: Range<Anchor>,
    example: Option<ExampleSpec>,
    /// When tracking started.
    enqueued_at: Instant,
    /// Time of the most recent edit observed in the tracked region (presumably).
    last_edit_at: Instant,
    e2e_latency: std::time::Duration,
}
 493
/// Tracking state for a buffer registered with a project's prediction state.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    /// Most recently observed snapshot of the buffer (presumably).
    snapshot: TextBufferSnapshot,
    /// Predictions in this buffer awaiting settlement.
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last recorded position in the buffer; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
 501
/// An edit event still being accumulated for one buffer; turned into a [`StoredEvent`] by
/// `LastEvent::finalize`.
#[derive(Clone)]
struct LastEvent {
    /// Snapshot before the first change in this event.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot after the latest change in this event.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    latest_edit_range: Range<Anchor>,
    /// Range covering all changes in this event, anchored in `new_snapshot`.
    total_edit_range: Range<Anchor>,
    /// `total_edit_range` as it stood at the last editing pause, if one occurred.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    /// Whether the changes came from an accepted prediction rather than manual typing.
    predicted: bool,
    /// Snapshot at the last editing pause; `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 515
 516impl LastEvent {
 517    pub fn finalize(
 518        &self,
 519        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 520        cx: &App,
 521    ) -> Option<StoredEvent> {
 522        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 523        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 524
 525        let in_open_source_repo =
 526            [self.new_file.as_ref(), self.old_file.as_ref()]
 527                .iter()
 528                .all(|file| {
 529                    file.is_some_and(|file| {
 530                        license_detection_watchers
 531                            .get(&file.worktree_id(cx))
 532                            .is_some_and(|watcher| watcher.is_project_open_source())
 533                    })
 534                });
 535
 536        let (diff, edit_range) = compute_diff_between_snapshots_in_range(
 537            &self.old_snapshot,
 538            &self.new_snapshot,
 539            &self.total_edit_range,
 540        )?;
 541
 542        if path == old_path && diff.is_empty() {
 543            None
 544        } else {
 545            Some(StoredEvent {
 546                event: Arc::new(zeta_prompt::Event::BufferChange {
 547                    old_path,
 548                    path,
 549                    diff,
 550                    in_open_source_repo,
 551                    predicted: self.predicted,
 552                }),
 553                old_snapshot: self.old_snapshot.clone(),
 554                new_snapshot_version: self.new_snapshot.version.clone(),
 555                total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
 556                    ..self.new_snapshot.anchor_before(edit_range.end),
 557            })
 558        }
 559    }
 560
 561    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 562        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 563            return (self.clone(), None);
 564        };
 565
 566        let total_edit_range_before_pause = self
 567            .total_edit_range_at_last_pause_boundary
 568            .clone()
 569            .unwrap_or_else(|| self.total_edit_range.clone());
 570
 571        let Some(total_edit_range_after_pause) =
 572            compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
 573        else {
 574            return (self.clone(), None);
 575        };
 576
 577        let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
 578        let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
 579
 580        let before = LastEvent {
 581            old_snapshot: self.old_snapshot.clone(),
 582            new_snapshot: boundary_snapshot.clone(),
 583            old_file: self.old_file.clone(),
 584            new_file: self.new_file.clone(),
 585            latest_edit_range: latest_edit_range_before_pause,
 586            total_edit_range: total_edit_range_before_pause,
 587            total_edit_range_at_last_pause_boundary: None,
 588            predicted: self.predicted,
 589            snapshot_after_last_editing_pause: None,
 590            last_edit_time: self.last_edit_time,
 591        };
 592
 593        let after = LastEvent {
 594            old_snapshot: boundary_snapshot.clone(),
 595            new_snapshot: self.new_snapshot.clone(),
 596            old_file: self.old_file.clone(),
 597            new_file: self.new_file.clone(),
 598            latest_edit_range: latest_edit_range_after_pause,
 599            total_edit_range: total_edit_range_after_pause,
 600            total_edit_range_at_last_pause_boundary: None,
 601            predicted: self.predicted,
 602            snapshot_after_last_editing_pause: None,
 603            last_edit_time: self.last_edit_time,
 604        };
 605
 606        (before, Some(after))
 607    }
 608}
 609
 610fn compute_total_edit_range_between_snapshots(
 611    old_snapshot: &TextBufferSnapshot,
 612    new_snapshot: &TextBufferSnapshot,
 613) -> Option<Range<Anchor>> {
 614    let edits: Vec<Edit<usize>> = new_snapshot
 615        .edits_since::<usize>(&old_snapshot.version)
 616        .collect();
 617
 618    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 619    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 620    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 621
 622    Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
 623}
 624
/// Maps `total_edit_range` (anchors resolved in `new_snapshot`) back to a point range in
/// `old_snapshot`, using the edits made between the two snapshots.
///
/// Each endpoint is translated independently:
/// - If it lies before an edit (in new coordinates), it is shifted back by `delta`, the net
///   length change of all earlier edits.
/// - If it lies within an edit, boundaries included, it snaps to that edit's old start (for
///   the range start) or old end (for the range end), so the replaced text is fully covered.
/// - If it lies after every edit, it is shifted back by the total delta, saturating at 0.
///
/// Returns `None` if shifting an endpoint that precedes an edit would underflow.
fn compute_old_range_for_new_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<Range<Point>> {
    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);

    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();
    let mut old_start_offset = None;
    let mut old_end_offset = None;
    // Net length change (new minus old) of the edits processed so far.
    let mut delta: isize = 0;

    for edit in &edits {
        // The first edit ending at or after the start decides it: the start is either before
        // the edit (shift by `delta`) or inside it (snap to the edit's old start).
        // NOTE(review): a start exactly at `edit.new.end` also snaps, which pulls the adjacent
        // edit into the old range; presumably harmless since callers add context lines.
        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
            old_start_offset = Some(if new_start_offset < edit.new.start {
                new_start_offset.checked_add_signed(-delta)?
            } else {
                edit.old.start
            });
        }

        // Same rule for the end, snapping to the edit's old end when inside it.
        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
            old_end_offset = Some(if new_end_offset < edit.new.start {
                new_end_offset.checked_add_signed(-delta)?
            } else {
                edit.old.end
            });
        }

        // Updated only after both checks, so `delta` covers strictly earlier edits above.
        delta += edit.new.len() as isize - edit.old.len() as isize;
    }

    // Endpoints beyond the last edit are shifted by the total delta.
    let old_start_offset =
        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
    let old_end_offset =
        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));

    Some(
        old_snapshot.offset_to_point(old_start_offset)
            ..old_snapshot.offset_to_point(old_end_offset),
    )
}
 670
 671fn compute_diff_between_snapshots_in_range(
 672    old_snapshot: &TextBufferSnapshot,
 673    new_snapshot: &TextBufferSnapshot,
 674    total_edit_range: &Range<Anchor>,
 675) -> Option<(String, Range<Point>)> {
 676    let new_start_point = total_edit_range.start.to_point(new_snapshot);
 677    let new_end_point = total_edit_range.end.to_point(new_snapshot);
 678    let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
 679    let old_start_point = old_range.start;
 680    let old_end_point = old_range.end;
 681
 682    const CONTEXT_LINES: u32 = 3;
 683
 684    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 685    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 686    let old_context_end_row =
 687        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 688    let new_context_end_row =
 689        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 690
 691    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 692    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 693    let old_end_line_offset = old_snapshot
 694        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 695    let new_end_line_offset = new_snapshot
 696        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 697    let old_edit_range = old_start_line_offset..old_end_line_offset;
 698    let new_edit_range = new_start_line_offset..new_end_line_offset;
 699
 700    if new_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
 701        || old_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
 702    {
 703        return None;
 704    }
 705
 706    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 707    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 708
 709    let diff = language::unified_diff_with_offsets(
 710        &old_region_text,
 711        &new_region_text,
 712        old_context_start_row,
 713        new_context_start_row,
 714    );
 715
 716    Some((diff, new_start_point..new_end_point))
 717}
 718
 719fn buffer_path_with_id_fallback(
 720    file: Option<&Arc<dyn File>>,
 721    snapshot: &TextBufferSnapshot,
 722    cx: &App,
 723) -> Arc<Path> {
 724    if let Some(file) = file {
 725        file.full_path(cx).into()
 726    } else {
 727        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 728    }
 729}
 730
 731impl EditPredictionStore {
 732    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 733        cx.try_global::<EditPredictionStoreGlobal>()
 734            .map(|global| global.0.clone())
 735    }
 736
 737    pub fn global(
 738        client: &Arc<Client>,
 739        user_store: &Entity<UserStore>,
 740        cx: &mut App,
 741    ) -> Entity<Self> {
 742        cx.try_global::<EditPredictionStoreGlobal>()
 743            .map(|global| global.0.clone())
 744            .unwrap_or_else(|| {
 745                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 746                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 747                ep_store
 748            })
 749    }
 750
 751    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 752        let data_collection_choice = Self::load_data_collection_choice(cx);
 753
 754        let llm_token = global_llm_api_token(cx);
 755
 756        let (reject_tx, reject_rx) = mpsc::unbounded();
 757        cx.background_spawn({
 758            let client = client.clone();
 759            let llm_token = llm_token.clone();
 760            let app_version = AppVersion::global(cx);
 761            let background_executor = cx.background_executor().clone();
 762            async move {
 763                Self::handle_rejected_predictions(
 764                    reject_rx,
 765                    client,
 766                    llm_token,
 767                    app_version,
 768                    background_executor,
 769                )
 770                .await
 771            }
 772        })
 773        .detach();
 774
 775        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
 776        cx.spawn(async move |this, cx| {
 777            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
 778        })
 779        .detach();
 780
 781        let mut current_user = user_store.read(cx).watch_current_user();
 782        let fetch_experiments_task = cx.spawn(async move |this, cx| {
 783            while current_user.borrow().is_none() {
 784                current_user.next().await;
 785            }
 786
 787            this.update(cx, |this, cx| {
 788                if cx.is_staff() {
 789                    this.refresh_available_experiments(cx);
 790                }
 791            })
 792            .log_err();
 793        });
 794
 795        let credentials_provider = zed_credentials_provider::global(cx);
 796
 797        let this = Self {
 798            projects: HashMap::default(),
 799            client,
 800            user_store,
 801            llm_token,
 802            _fetch_experiments_task: fetch_experiments_task,
 803            update_required: false,
 804            edit_prediction_model: EditPredictionModel::Zeta,
 805            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 806            preferred_experiment: None,
 807            available_experiments: Vec::new(),
 808            mercury: Mercury::new(cx),
 809
 810            data_collection_choice,
 811            reject_predictions_tx: reject_tx,
 812            settled_predictions_tx,
 813            rated_predictions: Default::default(),
 814            shown_predictions: Default::default(),
 815            #[cfg(test)]
 816            settled_event_callback: None,
 817
 818            credentials_provider,
 819        };
 820
 821        this
 822    }
 823
 824    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 825        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 826        let format = ZetaFormat::parse(&version_str).ok()?;
 827        let model_id = env::var("ZED_ZETA_MODEL").ok();
 828        let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
 829        Some(Zeta2RawConfig {
 830            model_id,
 831            environment,
 832            format,
 833        })
 834    }
 835
 836    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 837        self.edit_prediction_model = model;
 838    }
 839
 840    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
 841        self.zeta2_raw_config = Some(config);
 842    }
 843
 844    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
 845        self.zeta2_raw_config.as_ref()
 846    }
 847
 848    pub fn preferred_experiment(&self) -> Option<&str> {
 849        self.preferred_experiment.as_deref()
 850    }
 851
 852    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
 853        self.preferred_experiment = experiment;
 854    }
 855
 856    pub fn available_experiments(&self) -> &[String] {
 857        &self.available_experiments
 858    }
 859
 860    pub fn active_experiment(&self) -> Option<&str> {
 861        self.preferred_experiment.as_deref().or_else(|| {
 862            self.shown_predictions
 863                .iter()
 864                .find_map(|p| p.model_version.as_ref())
 865                .and_then(|model_version| model_version.strip_prefix("zeta2:"))
 866        })
 867    }
 868
 869    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
 870        let client = self.client.clone();
 871        let llm_token = self.llm_token.clone();
 872        let app_version = AppVersion::global(cx);
 873        let organization_id = self
 874            .user_store
 875            .read(cx)
 876            .current_organization()
 877            .map(|organization| organization.id.clone());
 878
 879        cx.spawn(async move |this, cx| {
 880            let experiments = cx
 881                .background_spawn(async move {
 882                    let http_client = client.http_client();
 883                    let token = client
 884                        .acquire_llm_token(&llm_token, organization_id.clone())
 885                        .await?;
 886                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
 887                    let request = http_client::Request::builder()
 888                        .method(Method::GET)
 889                        .uri(url.as_ref())
 890                        .header("Authorization", format!("Bearer {}", token))
 891                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 892                        .body(Default::default())?;
 893                    let mut response = http_client.send(request).await?;
 894                    if response.status().is_success() {
 895                        let mut body = Vec::new();
 896                        response.body_mut().read_to_end(&mut body).await?;
 897                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
 898                        Ok(experiments)
 899                    } else {
 900                        let mut body = String::new();
 901                        response.body_mut().read_to_string(&mut body).await?;
 902                        anyhow::bail!(
 903                            "Failed to fetch experiments: {:?}\nBody: {}",
 904                            response.status(),
 905                            body
 906                        );
 907                    }
 908                })
 909                .await?;
 910            this.update(cx, |this, cx| {
 911                this.available_experiments = experiments;
 912                cx.notify();
 913            })?;
 914            anyhow::Ok(())
 915        })
 916        .detach_and_log_err(cx);
 917    }
 918
 919    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 920        use ui::IconName;
 921        match self.edit_prediction_model {
 922            EditPredictionModel::Mercury => {
 923                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 924            }
 925            EditPredictionModel::Zeta => {
 926                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 927                    .with_disabled(IconName::ZedPredictDisabled)
 928                    .with_up(IconName::ZedPredictUp)
 929                    .with_down(IconName::ZedPredictDown)
 930                    .with_error(IconName::ZedPredictError)
 931            }
 932            EditPredictionModel::Fim { .. } => {
 933                let settings = &all_language_settings(None, cx).edit_predictions;
 934                match settings.provider {
 935                    EditPredictionProvider::Ollama => {
 936                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 937                    }
 938                    _ => {
 939                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 940                    }
 941                }
 942            }
 943        }
 944    }
 945
 946    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 947        self.mercury.api_token.read(cx).has_key()
 948    }
 949
 950    pub fn mercury_has_payment_required_error(&self) -> bool {
 951        self.mercury.has_payment_required_error()
 952    }
 953
 954    pub fn clear_history(&mut self) {
 955        for project_state in self.projects.values_mut() {
 956            project_state.events.clear();
 957            project_state.last_event.take();
 958        }
 959    }
 960
 961    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 962        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 963            project_state.events.clear();
 964            project_state.last_event.take();
 965        }
 966    }
 967
 968    pub fn edit_history_for_project(
 969        &self,
 970        project: &Entity<Project>,
 971        cx: &App,
 972    ) -> Vec<StoredEvent> {
 973        self.projects
 974            .get(&project.entity_id())
 975            .map(|project_state| project_state.events(cx))
 976            .unwrap_or_default()
 977    }
 978
 979    pub fn context_for_project<'a>(
 980        &'a self,
 981        project: &Entity<Project>,
 982        cx: &'a mut App,
 983    ) -> Vec<RelatedFile> {
 984        self.projects
 985            .get(&project.entity_id())
 986            .map(|project_state| {
 987                project_state.context.update(cx, |context, cx| {
 988                    context
 989                        .related_files_with_buffers(cx)
 990                        .map(|(mut related_file, buffer)| {
 991                            related_file.in_open_source_repo = buffer
 992                                .read(cx)
 993                                .file()
 994                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
 995                            related_file
 996                        })
 997                        .collect()
 998                })
 999            })
1000            .unwrap_or_default()
1001    }
1002
1003    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1004        self.projects
1005            .get(&project.entity_id())
1006            .and_then(|project| project.copilot.clone())
1007    }
1008
1009    pub fn start_copilot_for_project(
1010        &mut self,
1011        project: &Entity<Project>,
1012        cx: &mut Context<Self>,
1013    ) -> Option<Entity<Copilot>> {
1014        if DisableAiSettings::get(None, cx).disable_ai {
1015            return None;
1016        }
1017        let state = self.get_or_init_project(project, cx);
1018
1019        if state.copilot.is_some() {
1020            return state.copilot.clone();
1021        }
1022        let _project = project.clone();
1023        let project = project.read(cx);
1024
1025        let node = project.node_runtime().cloned();
1026        if let Some(node) = node {
1027            let next_id = project.languages().next_language_server_id();
1028            let fs = project.fs().clone();
1029
1030            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1031            state.copilot = Some(copilot.clone());
1032            Some(copilot)
1033        } else {
1034            None
1035        }
1036    }
1037
1038    pub fn context_for_project_with_buffers<'a>(
1039        &'a self,
1040        project: &Entity<Project>,
1041        cx: &'a mut App,
1042    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1043        self.projects
1044            .get(&project.entity_id())
1045            .map(|project| {
1046                project.context.update(cx, |context, cx| {
1047                    context.related_files_with_buffers(cx).collect()
1048                })
1049            })
1050            .unwrap_or_default()
1051    }
1052
1053    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1054        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1055            self.user_store.read(cx).edit_prediction_usage()
1056        } else {
1057            None
1058        }
1059    }
1060
1061    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1062        self.get_or_init_project(project, cx);
1063    }
1064
1065    pub fn register_buffer(
1066        &mut self,
1067        buffer: &Entity<Buffer>,
1068        project: &Entity<Project>,
1069        cx: &mut Context<Self>,
1070    ) {
1071        let project_state = self.get_or_init_project(project, cx);
1072        Self::register_buffer_impl(project_state, buffer, project, cx);
1073    }
1074
    /// Returns the tracked state for `project`, creating it on first use.
    ///
    /// On creation this builds a `RelatedExcerptStore` for the project and
    /// forwards its events to `handle_excerpt_store_event`, subscribes to the
    /// project's own events via `handle_project_event`, and removes the state
    /// again once the project entity is released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Detached: the handler looks the project up by id on every event, so it
                    // does nothing once this project's state has been removed.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop all per-project state when the project entity itself goes away.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1114
1115    pub fn remove_project(&mut self, project: &Entity<Project>) {
1116        self.projects.remove(&project.entity_id());
1117    }
1118
1119    fn handle_excerpt_store_event(
1120        &mut self,
1121        project_entity_id: EntityId,
1122        event: &RelatedExcerptStoreEvent,
1123    ) {
1124        if let Some(project_state) = self.projects.get(&project_entity_id) {
1125            if let Some(debug_tx) = project_state.debug_tx.clone() {
1126                match event {
1127                    RelatedExcerptStoreEvent::StartedRefresh => {
1128                        debug_tx
1129                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
1130                                ContextRetrievalStartedDebugEvent {
1131                                    project_entity_id: project_entity_id,
1132                                    timestamp: Instant::now(),
1133                                    search_prompt: String::new(),
1134                                },
1135                            ))
1136                            .ok();
1137                    }
1138                    RelatedExcerptStoreEvent::FinishedRefresh {
1139                        cache_hit_count,
1140                        cache_miss_count,
1141                        mean_definition_latency,
1142                        max_definition_latency,
1143                    } => {
1144                        debug_tx
1145                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
1146                                ContextRetrievalFinishedDebugEvent {
1147                                    project_entity_id: project_entity_id,
1148                                    timestamp: Instant::now(),
1149                                    metadata: vec![
1150                                        (
1151                                            "Cache Hits",
1152                                            format!(
1153                                                "{}/{}",
1154                                                cache_hit_count,
1155                                                cache_hit_count + cache_miss_count
1156                                            )
1157                                            .into(),
1158                                        ),
1159                                        (
1160                                            "Max LSP Time",
1161                                            format!("{} ms", max_definition_latency.as_millis())
1162                                                .into(),
1163                                        ),
1164                                        (
1165                                            "Mean LSP Time",
1166                                            format!("{} ms", mean_definition_latency.as_millis())
1167                                                .into(),
1168                                        ),
1169                                    ],
1170                                },
1171                            ))
1172                            .ok();
1173                    }
1174                }
1175            }
1176        }
1177    }
1178
1179    pub fn debug_info(
1180        &mut self,
1181        project: &Entity<Project>,
1182        cx: &mut Context<Self>,
1183    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1184        let project_state = self.get_or_init_project(project, cx);
1185        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1186        project_state.debug_tx = Some(debug_watch_tx);
1187        debug_watch_rx
1188    }
1189
1190    fn handle_project_event(
1191        &mut self,
1192        project: Entity<Project>,
1193        event: &project::Event,
1194        cx: &mut Context<Self>,
1195    ) {
1196        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1197            return;
1198        }
1199        // TODO [zeta2] init with recent paths
1200        match event {
1201            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1202                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1203                    return;
1204                };
1205                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1206                if let Some(path) = path {
1207                    if let Some(ix) = project_state
1208                        .recent_paths
1209                        .iter()
1210                        .position(|probe| probe == &path)
1211                    {
1212                        project_state.recent_paths.remove(ix);
1213                    }
1214                    project_state.recent_paths.push_front(path);
1215                }
1216            }
1217            project::Event::DiagnosticsUpdated { .. } => {
1218                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1219                    self.refresh_prediction_from_diagnostics(
1220                        project,
1221                        DiagnosticSearchScope::Global,
1222                        cx,
1223                    );
1224                }
1225            }
1226            _ => (),
1227        }
1228    }
1229
    /// Ensures `buffer` is tracked in `project_state` and returns its entry.
    ///
    /// If the buffer is backed by a file in one of the project's worktrees, a
    /// `LicenseDetectionWatcher` is also created for that worktree (once per
    /// worktree) and removed again when the worktree is released.
    ///
    /// On first registration the buffer's current text snapshot and file are
    /// captured, local and remote edits are routed to `report_changes_for_buffer`
    /// (as non-predicted edits), and the entry is removed when the buffer is
    /// released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher when its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Only a weak handle is held; edits are ignored once the
                            // project has been dropped.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        // Untrack the buffer when it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1298
    /// Records the latest changes to `buffer` in the project's edit history.
    ///
    /// * `is_predicted` - whether the change came from accepting a prediction
    ///   (`accept_current_prediction` passes `true`; buffer edit events pass `false`).
    /// * `is_local` - whether the edit was made locally; remote edits are only
    ///   recorded when they overlap the locality region.
    ///
    /// Consecutive, nearby edits of the same kind to the same buffer are coalesced
    /// into the in-progress `last_event`; otherwise that event is finalized into
    /// `events` and a new one is started.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the snapshot we last recorded.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut edit_range: Option<Range<Anchor>> = None;
        let now = cx.background_executor().now();

        // Span from the start of the first edit to the end of the last one.
        // NOTE(review): assumes `anchored_edits_since` yields edits in buffer order.
        for (_edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Edits touching a pending prediction's editable region restart its
        // settle (quiescence) timer.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Collaborator edits only enter the history when they overlap the
        // locality region.
        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if !include_in_history {
            return;
        }

        let is_recordable_history_edit =
            compute_diff_between_snapshots_in_range(&old_snapshot, &new_snapshot, &edit_range)
                .is_some();

        let events = &mut project_state.events;

        // An edit that produces no diff is not recorded itself; it only closes out
        // the in-progress event.
        if !is_recordable_history_edit {
            if let Some(event) = project_state.last_event.take() {
                if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                    // Evict the oldest event when the history is at capacity.
                    if events.len() + 1 >= EVENT_COUNT_MAX {
                        events.pop_front();
                    }
                    events.push_back(event);
                }
            }
            return;
        }

        if let Some(last_event) = project_state.last_event.as_mut() {
            // This edit directly continues the in-progress event's buffer state.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // After a pause of at least LAST_CHANGE_GROUPING_TIME, remember the
                // pre-edit state as a pause boundary within the coalesced event.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescable: finalize the previous event and start a new one.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Evict the oldest event when the history is at capacity.
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1431
1432    fn prediction_at(
1433        &mut self,
1434        buffer: &Entity<Buffer>,
1435        position: Option<language::Anchor>,
1436        project: &Entity<Project>,
1437        cx: &App,
1438    ) -> Option<BufferEditPrediction<'_>> {
1439        let project_state = self.projects.get_mut(&project.entity_id())?;
1440        if let Some(position) = position
1441            && let Some(buffer) = project_state
1442                .registered_buffers
1443                .get_mut(&buffer.entity_id())
1444        {
1445            buffer.last_position = Some(position);
1446        }
1447
1448        let CurrentEditPrediction {
1449            requested_by,
1450            prediction,
1451            ..
1452        } = project_state.current_prediction.as_ref()?;
1453
1454        if prediction.targets_buffer(buffer.read(cx)) {
1455            Some(BufferEditPrediction::Local { prediction })
1456        } else {
1457            let show_jump = match requested_by {
1458                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1459                    requested_by_buffer_id == &buffer.entity_id()
1460                }
1461                PredictionRequestedBy::DiagnosticsUpdate => true,
1462            };
1463
1464            if show_jump {
1465                Some(BufferEditPrediction::Jump { prediction })
1466            } else {
1467                None
1468            }
1469        }
1470    }
1471
1472    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1473        let Some(current_prediction) = self
1474            .projects
1475            .get_mut(&project.entity_id())
1476            .and_then(|project_state| project_state.current_prediction.take())
1477        else {
1478            return;
1479        };
1480
1481        self.report_changes_for_buffer(
1482            &current_prediction.prediction.buffer,
1483            project,
1484            true,
1485            true,
1486            cx,
1487        );
1488
1489        // can't hold &mut project_state ref across report_changes_for_buffer_call
1490        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1491            return;
1492        };
1493
1494        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1495            project_state.cancel_pending_prediction(pending_prediction, cx);
1496        }
1497
1498        match self.edit_prediction_model {
1499            EditPredictionModel::Mercury => {
1500                mercury::edit_prediction_accepted(
1501                    current_prediction.prediction.id,
1502                    self.client.http_client(),
1503                    cx,
1504                );
1505            }
1506            EditPredictionModel::Zeta => {
1507                let is_cloud = !matches!(
1508                    all_language_settings(None, cx).edit_predictions.provider,
1509                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1510                );
1511                if is_cloud {
1512                    zeta::edit_prediction_accepted(self, current_prediction, cx)
1513                }
1514            }
1515            EditPredictionModel::Fim { .. } => {}
1516        }
1517    }
1518
    /// Background loop that batches rejected predictions from `rx` and posts them
    /// to the `/predict_edits/reject` endpoint.
    ///
    /// While the batch holds fewer than half of
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` items, it waits up to
    /// `REJECT_REQUEST_DEBOUNCE` for more (pulling an already-queued item
    /// immediately). Each request carries at most
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the newest items, which are
    /// removed only if the request succeeds; failed items stay queued for the next
    /// attempt. The loop ends when the channel closes, after flushing what it has.
    ///
    /// NOTE(review): under persistent request failures `batched` grows without
    /// bound, and each batch is sent with the `organization_id` of the most recently
    /// received payload even if earlier items came from another organization —
    /// confirm both are intended.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            // Small batch: either pull the next queued item right away, or wait for
            // the debounce timer before flushing. A closed channel (peek yields
            // `None`) falls through and flushes immediately.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            // NOTE(review): a failure to build this URL panics and ends the task.
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Send the newest `flush_count` items.
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Only drop the items that were actually delivered; on failure they are
            // retried together with later rejections.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1582
1583    async fn run_settled_predictions_worker(
1584        this: WeakEntity<Self>,
1585        mut rx: UnboundedReceiver<Instant>,
1586        cx: &mut AsyncApp,
1587    ) {
1588        let mut next_wake_time: Option<Instant> = None;
1589        loop {
1590            let now = cx.background_executor().now();
1591            if let Some(wake_time) = next_wake_time.take() {
1592                cx.background_executor()
1593                    .timer(wake_time.duration_since(now))
1594                    .await;
1595            } else {
1596                let Some(new_enqueue_time) = rx.next().await else {
1597                    break;
1598                };
1599                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1600                while rx.next().now_or_never().flatten().is_some() {}
1601                continue;
1602            }
1603
1604            let Some(this) = this.upgrade() else {
1605                break;
1606            };
1607
1608            let now = cx.background_executor().now();
1609
1610            let mut oldest_edited_at = None;
1611
1612            this.update(cx, |this, _| {
1613                for (_, project_state) in this.projects.iter_mut() {
1614                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
1615                        registered_buffer
1616                            .pending_predictions
1617                            .retain_mut(|pending_prediction| {
1618                                let age =
1619                                    now.saturating_duration_since(pending_prediction.enqueued_at);
1620                                if age >= EDIT_PREDICTION_SETTLED_TTL {
1621                                    return false;
1622                                }
1623
1624                                let quiet_for =
1625                                    now.saturating_duration_since(pending_prediction.last_edit_at);
1626                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
1627                                    let settled_editable_region = registered_buffer
1628                                        .snapshot
1629                                        .text_for_range(
1630                                            pending_prediction.editable_anchor_range.clone(),
1631                                        )
1632                                        .collect::<String>();
1633
1634                                    #[cfg(test)]
1635                                    if let Some(callback) = &this.settled_event_callback {
1636                                        callback(
1637                                            pending_prediction.request_id.clone(),
1638                                            settled_editable_region.clone(),
1639                                        );
1640                                    }
1641
1642                                    telemetry::event!(
1643                                        EDIT_PREDICTION_SETTLED_EVENT,
1644                                        request_id = pending_prediction.request_id.0.clone(),
1645                                        settled_editable_region,
1646                                        example = pending_prediction.example.take(),
1647                                        e2e_latency = pending_prediction.e2e_latency.as_millis(),
1648                                    );
1649
1650                                    return false;
1651                                }
1652
1653                                if oldest_edited_at
1654                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
1655                                {
1656                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
1657                                }
1658
1659                                true
1660                            });
1661                    }
1662                }
1663            });
1664
1665            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1666        }
1667    }
1668
1669    pub(crate) fn enqueue_settled_prediction(
1670        &mut self,
1671        request_id: EditPredictionId,
1672        project: &Entity<Project>,
1673        edited_buffer: &Entity<Buffer>,
1674        edited_buffer_snapshot: &BufferSnapshot,
1675        editable_offset_range: Range<usize>,
1676        example: Option<ExampleSpec>,
1677        e2e_latency: std::time::Duration,
1678        cx: &mut Context<Self>,
1679    ) {
1680        let this = &mut *self;
1681        let project_state = this.get_or_init_project(project, cx);
1682        if let Some(buffer) = project_state
1683            .registered_buffers
1684            .get_mut(&edited_buffer.entity_id())
1685        {
1686            let now = cx.background_executor().now();
1687            buffer.pending_predictions.push(PendingSettledPrediction {
1688                request_id: request_id,
1689                editable_anchor_range: edited_buffer_snapshot
1690                    .anchor_range_around(editable_offset_range),
1691                example,
1692                e2e_latency,
1693                enqueued_at: now,
1694                last_edit_at: now,
1695            });
1696            this.settled_predictions_tx.unbounded_send(now).ok();
1697        }
1698    }
1699
1700    fn reject_current_prediction(
1701        &mut self,
1702        reason: EditPredictionRejectReason,
1703        project: &Entity<Project>,
1704        cx: &App,
1705    ) {
1706        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1707            project_state.pending_predictions.clear();
1708            if let Some(prediction) = project_state.current_prediction.take() {
1709                let model_version = prediction.prediction.model_version.clone();
1710                self.reject_prediction(
1711                    prediction.prediction.id,
1712                    reason,
1713                    prediction.was_shown,
1714                    model_version,
1715                    Some(prediction.e2e_latency),
1716                    cx,
1717                );
1718            }
1719        };
1720    }
1721
1722    fn did_show_current_prediction(
1723        &mut self,
1724        project: &Entity<Project>,
1725        display_type: edit_prediction_types::SuggestionDisplayType,
1726        _cx: &mut Context<Self>,
1727    ) {
1728        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1729            return;
1730        };
1731
1732        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1733            return;
1734        };
1735
1736        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1737        let previous_shown_with = current_prediction.shown_with;
1738
1739        if previous_shown_with.is_none() || !is_jump {
1740            current_prediction.shown_with = Some(display_type);
1741        }
1742
1743        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1744
1745        if is_first_non_jump_show {
1746            current_prediction.was_shown = true;
1747        }
1748
1749        if is_first_non_jump_show {
1750            self.shown_predictions
1751                .push_front(current_prediction.prediction.clone());
1752            if self.shown_predictions.len() > 50 {
1753                let completion = self.shown_predictions.pop_back().unwrap();
1754                self.rated_predictions.remove(&completion.id);
1755            }
1756        }
1757    }
1758
1759    fn reject_prediction(
1760        &mut self,
1761        prediction_id: EditPredictionId,
1762        reason: EditPredictionRejectReason,
1763        was_shown: bool,
1764        model_version: Option<String>,
1765        e2e_latency: Option<std::time::Duration>,
1766        cx: &App,
1767    ) {
1768        match self.edit_prediction_model {
1769            EditPredictionModel::Zeta => {
1770                let is_cloud = !matches!(
1771                    all_language_settings(None, cx).edit_predictions.provider,
1772                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1773                );
1774
1775                if is_cloud {
1776                    let organization_id = self
1777                        .user_store
1778                        .read(cx)
1779                        .current_organization()
1780                        .map(|organization| organization.id.clone());
1781
1782                    self.reject_predictions_tx
1783                        .unbounded_send(EditPredictionRejectionPayload {
1784                            rejection: EditPredictionRejection {
1785                                request_id: prediction_id.to_string(),
1786                                reason,
1787                                was_shown,
1788                                model_version,
1789                                e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
1790                            },
1791                            organization_id,
1792                        })
1793                        .log_err();
1794                }
1795            }
1796            EditPredictionModel::Mercury => {
1797                mercury::edit_prediction_rejected(
1798                    prediction_id,
1799                    was_shown,
1800                    reason,
1801                    self.client.http_client(),
1802                    cx,
1803                );
1804            }
1805            EditPredictionModel::Fim { .. } => {}
1806        }
1807    }
1808
1809    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1810        self.projects
1811            .get(&project.entity_id())
1812            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1813    }
1814
1815    pub fn refresh_prediction_from_buffer(
1816        &mut self,
1817        project: Entity<Project>,
1818        buffer: Entity<Buffer>,
1819        position: language::Anchor,
1820        cx: &mut Context<Self>,
1821    ) {
1822        self.queue_prediction_refresh(
1823            project.clone(),
1824            PredictEditsRequestTrigger::Other,
1825            buffer.entity_id(),
1826            cx,
1827            move |this, cx| {
1828                let Some(request_task) = this
1829                    .update(cx, |this, cx| {
1830                        this.request_prediction(
1831                            &project,
1832                            &buffer,
1833                            position,
1834                            PredictEditsRequestTrigger::Other,
1835                            cx,
1836                        )
1837                    })
1838                    .log_err()
1839                else {
1840                    return Task::ready(anyhow::Ok(None));
1841                };
1842
1843                cx.spawn(async move |_cx| {
1844                    request_task.await.map(|prediction_result| {
1845                        prediction_result.map(|prediction_result| {
1846                            (
1847                                prediction_result,
1848                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1849                            )
1850                        })
1851                    })
1852                })
1853            },
1854        )
1855    }
1856
    /// Queues a refresh that requests a prediction at a location chosen by
    /// `next_diagnostic_location`. It uses the `Diagnostics` trigger and `project` as
    /// the throttle entity.
    ///
    /// Returns without queuing anything in any of these cases:
    /// - the configured provider is not one handled by this store;
    /// - a workspace for `project` is following a leader in its active pane;
    /// - `project` has no state yet;
    /// - `project` already has a current prediction.
    ///
    /// Inside the queued task, the project's active buffer and cursor are read, and
    /// predictions must be enabled there. `scope` selects the range handed to
    /// `next_diagnostic_location`:
    /// - `Local` uses `DIAGNOSTIC_LINES_RANGE` rows on each side of the cursor;
    /// - `Global` uses an empty default range.
    ///
    /// `next_diagnostic_location` skips diagnostics whose primary range overlaps that
    /// range. If a current prediction appeared while the request was in flight, the
    /// result is turned into a `CurrentPreferred` rejection.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        // Don't request predictions while a workspace for this project follows a leader.
        if currently_following(&project, cx) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Resolve the active buffer and cursor now (when the throttle has elapsed),
                // bailing out if predictions are disabled at that spot.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    // Diagnostics overlapping this range are skipped by
                    // `next_diagnostic_location`; `Global` passes an empty range.
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // A buffer-driven prediction may have landed while this request was in
                    // flight; if so, keep it and mark this result as `CurrentPreferred`.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                    e2e_latency: prediction_result.e2e_latency,
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
1973
1974    fn predictions_enabled_at(
1975        snapshot: &BufferSnapshot,
1976        position: Option<language::Anchor>,
1977        cx: &App,
1978    ) -> bool {
1979        let file = snapshot.file();
1980        let all_settings = all_language_settings(file, cx);
1981        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1982            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1983        {
1984            return false;
1985        }
1986
1987        if let Some(last_position) = position {
1988            let settings = snapshot.settings_at(last_position, cx);
1989
1990            if !settings.edit_predictions_disabled_in.is_empty()
1991                && let Some(scope) = snapshot.language_scope_at(last_position)
1992                && let Some(scope_name) = scope.override_name()
1993                && settings
1994                    .edit_predictions_disabled_in
1995                    .iter()
1996                    .any(|s| s == scope_name)
1997            {
1998                return false;
1999            }
2000        }
2001
2002        true
2003    }
2004
    /// Minimum spacing between prediction refreshes that share a throttle slot and
    /// throttle entity; `queue_prediction_refresh` delays a request until this much
    /// time has passed since the previous one.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2006}
2007
2008fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2009    let Some(app_state) = AppState::try_global(cx) else {
2010        return false;
2011    };
2012
2013    app_state
2014        .workspace_store
2015        .read(cx)
2016        .workspaces()
2017        .filter_map(|workspace| workspace.upgrade())
2018        .any(|workspace| {
2019            workspace.read(cx).project().entity_id() == project.entity_id()
2020                && workspace
2021                    .read(cx)
2022                    .leader_for_pane(workspace.read(cx).active_pane())
2023                    .is_some()
2024        })
2025}
2026
2027fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2028    match provider {
2029        EditPredictionProvider::Zed
2030        | EditPredictionProvider::Mercury
2031        | EditPredictionProvider::Ollama
2032        | EditPredictionProvider::OpenAiCompatibleApi
2033        | EditPredictionProvider::Experimental(_) => true,
2034        EditPredictionProvider::None
2035        | EditPredictionProvider::Copilot
2036        | EditPredictionProvider::Codestral => false,
2037    }
2038}
2039
2040impl EditPredictionStore {
    /// Enqueues a throttled prediction refresh for `project`.
    ///
    /// `do_refresh` performs the actual request once the throttle allows it. Requests
    /// triggered by diagnostics use a separate throttle slot from all other triggers.
    /// A wait only applies when the last refresh in the same slot was for the same
    /// `throttle_entity`, and lasts until `THROTTLE_TIMEOUT` after that refresh.
    ///
    /// The queued task is tracked in `pending_predictions`, which holds at most
    /// `max_pending_predictions` entries (depends on the provider). When the limit is
    /// reached, the most recently enqueued entry is popped, replaced by this one, and
    /// cancelled.
    ///
    /// When `do_refresh` completes:
    /// - a successful result becomes the project's `current_prediction`, unless the
    ///   existing one should be kept;
    /// - every other result is reported through `reject_prediction`;
    /// - entries enqueued before this one are cancelled.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        request_trigger: PredictEditsRequestTrigger,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        /// Returns the throttle slot for `request_trigger`: diagnostics-triggered
        /// refreshes are throttled independently of all other triggers.
        fn select_throttle(
            project_state: &mut ProjectState,
            request_trigger: PredictEditsRequestTrigger,
        ) -> &mut Option<(EntityId, Instant)> {
            match request_trigger {
                PredictEditsRequestTrigger::Diagnostics => {
                    &mut project_state.last_jump_prediction_refresh
                }
                _ => &mut project_state.last_edit_prediction_refresh,
            }
        }

        // Providers outside this store should never reach this point.
        let (needs_acceptance_tracking, max_pending_predictions) =
            match all_language_settings(None, cx).edit_predictions.provider {
                EditPredictionProvider::Zed
                | EditPredictionProvider::Mercury
                | EditPredictionProvider::Experimental(_) => (true, 2),
                EditPredictionProvider::Ollama => (false, 1),
                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
                EditPredictionProvider::None
                | EditPredictionProvider::Copilot
                | EditPredictionProvider::Codestral => {
                    log::error!("queue_prediction_refresh called with non-store provider");
                    return;
                }
            };

        // NOTE(review): presumably tells `cancel_pending_prediction` it may drop the task
        // outright when no acceptance tracking is needed — confirm in that method.
        let drop_on_cancel = !needs_acceptance_tracking;
        let throttle_timeout = Self::THROTTLE_TIMEOUT;
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        // Snapshot of the throttle slot at enqueue time. If it differs when the throttle
        // elapses, another request has already gone out and this one is a duplicate.
        let throttle_at_enqueue = *select_throttle(project_state, request_trigger);

        let task = cx.spawn(async move |this, cx| {
            // Wait only if the last refresh in this slot was for the same entity and
            // happened less than `throttle_timeout` ago.
            let throttle_wait = this
                .update(cx, |this, cx| {
                    let project_state = this.get_or_init_project(&project, cx);
                    let throttle = *select_throttle(project_state, request_trigger);

                    let now = cx.background_executor().now();
                    throttle.and_then(|(last_entity, last_timestamp)| {
                        if throttle_entity != last_entity {
                            return None;
                        }
                        (last_timestamp + throttle_timeout).checked_duration_since(now)
                    })
                })
                .ok()
                .flatten();

            if let Some(timeout) = throttle_wait {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request. Also skip if another task already
            // proceeded since we were enqueued (duplicate).
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                let was_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);
                if was_cancelled {
                    return;
                }

                // Another request has been already sent since this was enqueued.
                // NOTE(review): returning here does not remove this task's entry from
                // `pending_predictions` (nothing after the early return below does),
                // which may keep `is_refreshing` true until a later completion or
                // cancellation drains it — confirm this is intended.
                if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
                    return;
                }

                // Claim the throttle slot for this request.
                let new_refresh = (throttle_entity, cx.background_executor().now());
                *select_throttle(project_state, request_trigger) = Some(new_refresh);
                is_cancelled = false;
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                                shown_with: None,
                                e2e_latency: prediction_result.e2e_latency,
                            };

                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    // Note: `reject_current_prediction` also clears
                                    // `pending_predictions`, so the loop below then finds
                                    // nothing to remove or cancel.
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                        cx,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                        new_prediction.prediction.model_version,
                                        Some(new_prediction.e2e_latency),
                                        cx,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(
                                prediction_result.id,
                                reject_reason,
                                false,
                                None,
                                Some(prediction_result.e2e_latency),
                                cx,
                            );
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-borrow: the `reject_*` calls above needed `this` mutably.
                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Remove this task's entry and cancel every entry queued before it.
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Below the limit, append. Otherwise replace the most recently enqueued pending
        // request with this one and cancel the replaced request.
        if project_state.pending_predictions.len() < max_pending_predictions {
            project_state
                .pending_predictions
                .push(PendingPrediction {
                    id: pending_prediction_id,
                    task,
                    drop_on_cancel,
                })
                .unwrap();
        } else {
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state
                .pending_predictions
                .push(PendingPrediction {
                    id: pending_prediction_id,
                    task,
                    drop_on_cancel,
                })
                .unwrap();
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
2252
2253    pub fn request_prediction(
2254        &mut self,
2255        project: &Entity<Project>,
2256        active_buffer: &Entity<Buffer>,
2257        position: language::Anchor,
2258        trigger: PredictEditsRequestTrigger,
2259        cx: &mut Context<Self>,
2260    ) -> Task<Result<Option<EditPredictionResult>>> {
2261        self.request_prediction_internal(
2262            project.clone(),
2263            active_buffer.clone(),
2264            position,
2265            trigger,
2266            cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2267            cx,
2268        )
2269    }
2270
2271    fn request_prediction_internal(
2272        &mut self,
2273        project: Entity<Project>,
2274        active_buffer: Entity<Buffer>,
2275        position: language::Anchor,
2276        trigger: PredictEditsRequestTrigger,
2277        allow_jump: bool,
2278        cx: &mut Context<Self>,
2279    ) -> Task<Result<Option<EditPredictionResult>>> {
2280        self.get_or_init_project(&project, cx);
2281        let project_state = self.projects.get(&project.entity_id()).unwrap();
2282        let stored_events = project_state.events(cx);
2283        let has_events = !stored_events.is_empty();
2284        let events: Vec<Arc<zeta_prompt::Event>> =
2285            stored_events.iter().map(|e| e.event.clone()).collect();
2286        let debug_tx = project_state.debug_tx.clone();
2287
2288        let snapshot = active_buffer.read(cx).snapshot();
2289        let cursor_point = position.to_point(&snapshot);
2290        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2291        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2292        let diagnostic_search_range =
2293            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2294
2295        let related_files = self.context_for_project(&project, cx);
2296
2297        let is_open_source = snapshot
2298            .file()
2299            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2300            && events.iter().all(|event| event.in_open_source_repo())
2301            && related_files.iter().all(|file| file.in_open_source_repo);
2302
2303        let can_collect_data = !cfg!(test)
2304            && is_open_source
2305            && self.is_data_collection_enabled(cx)
2306            && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2307
2308        let inputs = EditPredictionModelInput {
2309            project: project.clone(),
2310            buffer: active_buffer,
2311            snapshot,
2312            position,
2313            events,
2314            related_files,
2315            trigger,
2316            diagnostic_search_range: diagnostic_search_range,
2317            debug_tx,
2318            can_collect_data,
2319            is_open_source,
2320        };
2321
2322        let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2323
2324        let task = match self.edit_prediction_model {
2325            EditPredictionModel::Zeta => {
2326                zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2327            }
2328            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2329            EditPredictionModel::Mercury => {
2330                self.mercury
2331                    .request_prediction(inputs, self.credentials_provider.clone(), cx)
2332            }
2333        };
2334
2335        cx.spawn(async move |this, cx| {
2336            let prediction = task.await?;
2337
2338            // Only fall back to diagnostics-based prediction if we got a
2339            // the model had nothing to suggest for the buffer
2340            if prediction.is_none()
2341                && allow_jump
2342                && has_events
2343                && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2344            {
2345                this.update(cx, |this, cx| {
2346                    this.refresh_prediction_from_diagnostics(
2347                        project,
2348                        DiagnosticSearchScope::Local,
2349                        cx,
2350                    );
2351                })?;
2352                return anyhow::Ok(None);
2353            }
2354
2355            Ok(prediction)
2356        })
2357    }
2358
    /// Picks a diagnostic location to jump to for a diagnostics-based prediction.
    ///
    /// The active buffer is searched first. For each diagnostic group, only the
    /// primary entry is considered, and it is skipped when:
    /// - it overlaps `active_buffer_diagnostic_search_range`, or
    /// - it starts within `DIAGNOSTIC_LINES_RANGE` rows of a collaborator's
    ///   selection head but not within that many rows of
    ///   `active_buffer_cursor_point`.
    /// Among the remaining diagnostics, the one whose start row is closest to
    /// the cursor row is chosen.
    ///
    /// If nothing qualifies, other project paths that have diagnostic summaries
    /// are tried. They are ordered by how many leading path components they
    /// share with the active buffer's path, most first. Each candidate is opened.
    /// Buffers that contain any collaborator selection are skipped. The first
    /// remaining buffer with diagnostics yields the start of its diagnostic with
    /// the smallest `severity` value (presumably the most severe, since LSP
    /// orders ERROR lowest).
    ///
    /// Returns `Ok(None)` when no location is found. An error opening a
    /// candidate buffer is propagated and ends the search.
    pub(crate) async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Rows of the selection heads reported by `selections_in_range` across
        // the whole buffer. The `false` argument presumably excludes the local
        // user's selections, so these are collaborator cursors; confirm.
        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
            .flat_map(|(_, _, _, selections)| {
                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
            })
            .collect();

        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Diagnostics inside the range already given to the model are
                // not useful jump targets.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    return None;
                }
                // Leave diagnostics a collaborator is working near to them,
                // unless the local cursor is also nearby.
                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
                });
                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
                    <= DIAGNOSTIC_LINES_RANGE;
                if near_collaborator && !near_local {
                    return None;
                }
                Some(range.start)
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // `None` when the active buffer has no backing file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // Pair each other path that has a diagnostic summary with the number
            // of leading path components it shares with the active buffer.
            // NOTE(review): the `false` passed to `diagnostic_summaries` is
            // presumably "include ignored"; confirm.
            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
                project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .map(|(path, _, _)| {
                        let shared_prefix = path
                            .path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count();
                        (path, shared_prefix)
                    })
                    .collect()
            });

            // Closest paths first. `sort_by` is stable, so ties keep the order
            // in which the summaries were returned.
            candidates.sort_by(|a, b| b.1.cmp(&a.1));

            for (path, _) in candidates {
                let candidate_buffer = project
                    .update(cx, |project, cx| project.open_buffer(path, cx))
                    .await?;

                let (has_collaborators, diagnostic_position) =
                    candidate_buffer.read_with(cx, |buffer, _cx| {
                        let snapshot = buffer.snapshot();
                        let has_collaborators = snapshot
                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
                            .next()
                            .is_some();
                        let position = buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start);
                        (has_collaborators, position)
                    });

                // Skip buffers where someone else has a selection.
                if has_collaborators {
                    continue;
                }

                if let Some(position) = diagnostic_position {
                    jump_location = Some((candidate_buffer, position));
                    break;
                }
            }
        }

        anyhow::Ok(jump_location)
    }
2468
2469    async fn send_raw_llm_request(
2470        request: RawCompletionRequest,
2471        client: Arc<Client>,
2472        custom_url: Option<Arc<Url>>,
2473        llm_token: LlmApiToken,
2474        organization_id: Option<OrganizationId>,
2475        app_version: Version,
2476    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2477        let url = if let Some(custom_url) = custom_url {
2478            custom_url.as_ref().clone()
2479        } else {
2480            client
2481                .http_client()
2482                .build_zed_llm_url("/predict_edits/raw", &[])?
2483        };
2484
2485        Self::send_api_request(
2486            |builder| {
2487                let req = builder
2488                    .uri(url.as_ref())
2489                    .body(serde_json::to_string(&request)?.into());
2490                Ok(req?)
2491            },
2492            client,
2493            llm_token,
2494            organization_id,
2495            app_version,
2496            true,
2497        )
2498        .await
2499    }
2500
2501    pub(crate) async fn send_v3_request(
2502        input: ZetaPromptInput,
2503        client: Arc<Client>,
2504        llm_token: LlmApiToken,
2505        organization_id: Option<OrganizationId>,
2506        app_version: Version,
2507        trigger: PredictEditsRequestTrigger,
2508    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2509        let url = client
2510            .http_client()
2511            .build_zed_llm_url("/predict_edits/v3", &[])?;
2512
2513        let request = PredictEditsV3Request { input, trigger };
2514
2515        let json_bytes = serde_json::to_vec(&request)?;
2516        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2517
2518        Self::send_api_request(
2519            |builder| {
2520                let req = builder
2521                    .uri(url.as_ref())
2522                    .header("Content-Encoding", "zstd")
2523                    .body(compressed.clone().into());
2524                Ok(req?)
2525            },
2526            client,
2527            llm_token,
2528            organization_id,
2529            app_version,
2530            true,
2531        )
2532        .await
2533    }
2534
2535    async fn send_api_request<Res>(
2536        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2537        client: Arc<Client>,
2538        llm_token: LlmApiToken,
2539        organization_id: Option<OrganizationId>,
2540        app_version: Version,
2541        require_auth: bool,
2542    ) -> Result<(Res, Option<EditPredictionUsage>)>
2543    where
2544        Res: DeserializeOwned,
2545    {
2546        let http_client = client.http_client();
2547        let mut token = if require_auth {
2548            Some(
2549                client
2550                    .acquire_llm_token(&llm_token, organization_id.clone())
2551                    .await?,
2552            )
2553        } else {
2554            client
2555                .acquire_llm_token(&llm_token, organization_id.clone())
2556                .await
2557                .ok()
2558        };
2559        let mut did_retry = false;
2560
2561        loop {
2562            let request_builder = http_client::Request::builder().method(Method::POST);
2563
2564            let mut request_builder = request_builder
2565                .header("Content-Type", "application/json")
2566                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2567
2568            // Only add Authorization header if we have a token
2569            if let Some(ref token_value) = token {
2570                request_builder =
2571                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2572            }
2573
2574            let request = build(request_builder)?;
2575
2576            let mut response = http_client.send(request).await?;
2577
2578            if let Some(minimum_required_version) = response
2579                .headers()
2580                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2581                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2582            {
2583                anyhow::ensure!(
2584                    app_version >= minimum_required_version,
2585                    ZedUpdateRequiredError {
2586                        minimum_version: minimum_required_version
2587                    }
2588                );
2589            }
2590
2591            if response.status().is_success() {
2592                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2593
2594                let mut body = Vec::new();
2595                response.body_mut().read_to_end(&mut body).await?;
2596                return Ok((serde_json::from_slice(&body)?, usage));
2597            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2598                did_retry = true;
2599                token = Some(
2600                    client
2601                        .refresh_llm_token(&llm_token, organization_id.clone())
2602                        .await?,
2603                );
2604            } else {
2605                let mut body = String::new();
2606                response.body_mut().read_to_string(&mut body).await?;
2607                anyhow::bail!(
2608                    "Request failed with status: {:?}\nBody: {}",
2609                    response.status(),
2610                    body
2611                );
2612            }
2613        }
2614    }
2615
2616    pub fn refresh_context(
2617        &mut self,
2618        project: &Entity<Project>,
2619        buffer: &Entity<language::Buffer>,
2620        cursor_position: language::Anchor,
2621        cx: &mut Context<Self>,
2622    ) {
2623        self.get_or_init_project(project, cx)
2624            .context
2625            .update(cx, |store, cx| {
2626                store.refresh(buffer.clone(), cursor_position, cx);
2627            });
2628    }
2629
2630    #[cfg(feature = "cli-support")]
2631    pub fn set_context_for_buffer(
2632        &mut self,
2633        project: &Entity<Project>,
2634        related_files: Vec<RelatedFile>,
2635        cx: &mut Context<Self>,
2636    ) {
2637        self.get_or_init_project(project, cx)
2638            .context
2639            .update(cx, |store, cx| {
2640                store.set_related_files(related_files, cx);
2641            });
2642    }
2643
2644    #[cfg(feature = "cli-support")]
2645    pub fn set_recent_paths_for_project(
2646        &mut self,
2647        project: &Entity<Project>,
2648        paths: impl IntoIterator<Item = project::ProjectPath>,
2649        cx: &mut Context<Self>,
2650    ) {
2651        let project_state = self.get_or_init_project(project, cx);
2652        project_state.recent_paths = paths.into_iter().collect();
2653    }
2654
2655    fn is_file_open_source(
2656        &self,
2657        project: &Entity<Project>,
2658        file: &Arc<dyn File>,
2659        cx: &App,
2660    ) -> bool {
2661        if !file.is_local() || file.is_private() {
2662            return false;
2663        }
2664        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2665            return false;
2666        };
2667        project_state
2668            .license_detection_watchers
2669            .get(&file.worktree_id(cx))
2670            .as_ref()
2671            .is_some_and(|watcher| watcher.is_project_open_source())
2672    }
2673
2674    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2675        self.data_collection_choice.is_enabled(cx)
2676    }
2677
2678    fn load_data_collection_choice(cx: &App) -> DataCollectionChoice {
2679        let choice = KeyValueStore::global(cx)
2680            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2681            .log_err()
2682            .flatten();
2683
2684        match choice.as_deref() {
2685            Some("true") => DataCollectionChoice::Enabled,
2686            Some("false") => DataCollectionChoice::Disabled,
2687            Some(_) => {
2688                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2689                DataCollectionChoice::NotAnswered
2690            }
2691            None => DataCollectionChoice::NotAnswered,
2692        }
2693    }
2694
2695    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2696        self.data_collection_choice = self.data_collection_choice.toggle();
2697        let new_choice = self.data_collection_choice;
2698        let is_enabled = new_choice.is_enabled(cx);
2699        let kvp = KeyValueStore::global(cx);
2700        db::write_and_log(cx, move || async move {
2701            kvp.write_kvp(
2702                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2703                is_enabled.to_string(),
2704            )
2705            .await
2706        });
2707    }
2708
2709    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2710        self.shown_predictions.iter()
2711    }
2712
2713    pub fn shown_completions_len(&self) -> usize {
2714        self.shown_predictions.len()
2715    }
2716
2717    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2718        self.rated_predictions.contains(id)
2719    }
2720
2721    pub fn rate_prediction(
2722        &mut self,
2723        prediction: &EditPrediction,
2724        rating: EditPredictionRating,
2725        feedback: String,
2726        cx: &mut Context<Self>,
2727    ) {
2728        let organization = self.user_store.read(cx).current_organization();
2729
2730        self.rated_predictions.insert(prediction.id.clone());
2731
2732        cx.background_spawn({
2733            let client = self.client.clone();
2734            let prediction_id = prediction.id.to_string();
2735            let inputs = serde_json::to_value(&prediction.inputs);
2736            let output = prediction
2737                .edit_preview
2738                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2739            async move {
2740                client
2741                    .cloud_client()
2742                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2743                        organization_id: organization.map(|organization| organization.id.clone()),
2744                        request_id: prediction_id,
2745                        rating: match rating {
2746                            EditPredictionRating::Positive => "positive".to_string(),
2747                            EditPredictionRating::Negative => "negative".to_string(),
2748                        },
2749                        inputs: inputs?,
2750                        output,
2751                        feedback,
2752                    })
2753                    .await?;
2754
2755                anyhow::Ok(())
2756            }
2757        })
2758        .detach_and_log_err(cx);
2759
2760        cx.notify();
2761    }
2762}
2763
2764fn collaborator_edit_overlaps_locality_region(
2765    project_state: &ProjectState,
2766    project: &Entity<Project>,
2767    buffer: &Entity<Buffer>,
2768    snapshot: &BufferSnapshot,
2769    edit_range: &Range<Anchor>,
2770    cx: &App,
2771) -> bool {
2772    let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2773        return false;
2774    };
2775
2776    if active_buffer.entity_id() != buffer.entity_id() {
2777        return false;
2778    }
2779
2780    let locality_point_range = expand_context_syntactically_then_linewise(
2781        snapshot,
2782        (position..position).to_point(snapshot),
2783        COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2784    );
2785    let locality_anchor_range = snapshot.anchor_range_around(locality_point_range);
2786
2787    edit_range.overlaps(&locality_anchor_range, snapshot)
2788}
2789
/// Collapses the trailing run of mergeable events in `events` into a single
/// `BufferChange` event.
///
/// Nothing happens in any of these cases:
/// - the newest stored event's `old_snapshot` has a different `remote_id` than
///   `latest_snapshot`;
/// - `latest_snapshot` has not observed that event's `new_snapshot_version`;
/// - fewer than two trailing events can be merged.
///
/// Otherwise the walk goes backwards from the newest event. It counts
/// consecutive events for which `can_merge` holds against the newer event
/// that follows them. Those events are replaced by one event whose diff is
/// computed from the oldest event's `old_snapshot` to `end_snapshot`, limited
/// to the union of their `total_edit_range`s. The merged event keeps the
/// oldest event's paths and `in_open_source_repo`. It is marked `predicted`
/// only if every merged event was predicted. If
/// `compute_diff_between_snapshots_in_range` returns `None`, `events` is left
/// unchanged.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    if let Some(last_event) = events.back() {
        // Only merge events that belong to the buffer being edited now.
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Count the trailing run of events that can merge with their successor.
    // The newest event always counts, so a non-empty deque yields >= 1.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    if mergeable_count <= 1 {
        return;
    }

    // `mergeable_count >= 2` here, so this range is non-empty and `peek`
    // cannot fail.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    // The oldest event's range seeded `merged_edit_range`; fold in the rest.
    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // Predicted only if every merged event was predicted.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Replace the merged run with the single combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2872
2873fn merge_anchor_ranges(
2874    left: &Range<Anchor>,
2875    right: &Range<Anchor>,
2876    snapshot: &TextBufferSnapshot,
2877) -> Range<Anchor> {
2878    let start = if left.start.cmp(&right.start, snapshot).is_le() {
2879        left.start
2880    } else {
2881        right.start
2882    };
2883    let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2884        left.end
2885    } else {
2886        right.end
2887    };
2888    start..end
2889}
2890
2891#[derive(Error, Debug)]
2892#[error(
2893    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2894)]
2895pub struct ZedUpdateRequiredError {
2896    minimum_version: Version,
2897}
2898
2899#[derive(Debug, Clone, Copy)]
2900pub enum DataCollectionChoice {
2901    NotAnswered,
2902    Enabled,
2903    Disabled,
2904}
2905
2906impl DataCollectionChoice {
2907    pub fn is_enabled(self, cx: &App) -> bool {
2908        if cx.is_staff() {
2909            return true;
2910        }
2911        match self {
2912            Self::Enabled => true,
2913            Self::NotAnswered | Self::Disabled => false,
2914        }
2915    }
2916
2917    #[must_use]
2918    pub fn toggle(&self) -> DataCollectionChoice {
2919        match self {
2920            Self::Enabled => Self::Disabled,
2921            Self::Disabled => Self::Enabled,
2922            Self::NotAnswered => Self::Enabled,
2923        }
2924    }
2925}
2926
2927impl From<bool> for DataCollectionChoice {
2928    fn from(value: bool) -> Self {
2929        match value {
2930            true => DataCollectionChoice::Enabled,
2931            false => DataCollectionChoice::Disabled,
2932        }
2933    }
2934}
2935
2936struct ZedPredictUpsell;
2937
2938impl Dismissable for ZedPredictUpsell {
2939    const KEY: &'static str = "dismissed-edit-predict-upsell";
2940
2941    fn dismissed(cx: &App) -> bool {
2942        // To make this backwards compatible with older versions of Zed, we
2943        // check if the user has seen the previous Edit Prediction Onboarding
2944        // before, by checking the data collection choice which was written to
2945        // the database once the user clicked on "Accept and Enable"
2946        let kvp = KeyValueStore::global(cx);
2947        if kvp
2948            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2949            .log_err()
2950            .is_some_and(|s| s.is_some())
2951        {
2952            return true;
2953        }
2954
2955        kvp.read_kvp(Self::KEY)
2956            .log_err()
2957            .is_some_and(|s| s.is_some())
2958    }
2959}
2960
2961pub fn should_show_upsell_modal(cx: &App) -> bool {
2962    !ZedPredictUpsell::dismissed(cx)
2963}
2964
2965pub fn init(cx: &mut App) {
2966    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2967        workspace.register_action(
2968            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2969                ZedPredictModal::toggle(
2970                    workspace,
2971                    workspace.user_store().clone(),
2972                    workspace.client().clone(),
2973                    window,
2974                    cx,
2975                )
2976            },
2977        );
2978
2979        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2980            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2981                settings
2982                    .project
2983                    .all_languages
2984                    .edit_predictions
2985                    .get_or_insert_default()
2986                    .provider = Some(EditPredictionProvider::None)
2987            });
2988        });
2989        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2990            EditPredictionStore::try_global(cx).and_then(|store| {
2991                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2992            })
2993        }
2994
2995        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2996            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2997                copilot_ui::initiate_sign_in(copilot, window, cx);
2998            }
2999        });
3000        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3001            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3002                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3003            }
3004        });
3005        workspace.register_action(|workspace, _: &SignOut, window, cx| {
3006            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3007                copilot_ui::initiate_sign_out(copilot, window, cx);
3008            }
3009        });
3010    })
3011    .detach();
3012}