edit_prediction.rs

   1use anyhow::Result;
   2use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
   3use cloud_api_client::LlmApiToken;
   4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
   5use cloud_llm_client::predict_edits_v3::{
   6    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   7};
   8use cloud_llm_client::{
   9    EditPredictionRejectReason, EditPredictionRejection,
  10    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  11    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  12};
  13use collections::{HashMap, HashSet};
  14use copilot::{Copilot, Reinstall, SignIn, SignOut};
  15use credentials_provider::CredentialsProvider;
  16use db::kvp::{Dismissable, KeyValueStore};
  17use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  18use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  19use futures::{
  20    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  21    channel::mpsc::{self, UnboundedReceiver},
  22    select_biased,
  23};
  24use gpui::BackgroundExecutor;
  25use gpui::http_client::Url;
  26use gpui::{
  27    App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
  28    http_client::{self, AsyncBody, Method},
  29    prelude::*,
  30};
  31use heapless::Vec as ArrayVec;
  32use language::language_settings::all_language_settings;
  33use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  34use language::{BufferSnapshot, OffsetRangeExt};
  35use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  36use release_channel::AppVersion;
  37use semver::Version;
  38use serde::de::DeserializeOwned;
  39use settings::{
  40    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  41};
  42use std::collections::{VecDeque, hash_map};
  43use std::env;
  44use text::{AnchorRangeExt, Edit};
  45use workspace::{AppState, Workspace};
  46use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  47
  48use std::mem;
  49use std::ops::Range;
  50use std::path::Path;
  51use std::rc::Rc;
  52use std::str::FromStr as _;
  53use std::sync::Arc;
  54use std::time::{Duration, Instant};
  55
  56use thiserror::Error;
  57use util::{RangeExt as _, ResultExt as _};
  58
  59pub mod cursor_excerpt;
  60pub mod example_spec;
  61pub mod fim;
  62mod license_detection;
  63pub mod mercury;
  64pub mod ollama;
  65mod onboarding_modal;
  66pub mod open_ai_response;
  67mod prediction;
  68
  69pub mod udiff;
  70
  71mod capture_example;
  72pub mod open_ai_compatible;
  73mod zed_edit_prediction_delegate;
  74pub mod zeta;
  75
  76#[cfg(test)]
  77mod edit_prediction_tests;
  78
  79use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
  80use crate::example_spec::ExampleSpec;
  81use crate::license_detection::LicenseDetectionWatcher;
  82use crate::mercury::Mercury;
  83use crate::onboarding_modal::ZedPredictModal;
  84pub use crate::prediction::EditPrediction;
  85pub use crate::prediction::EditPredictionId;
  86use crate::prediction::EditPredictionResult;
  87pub use capture_example::capture_example;
  88pub use language_model::ApiKeyState;
  89pub use telemetry_events::EditPredictionRating;
  90pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  91
  92actions!(
  93    edit_prediction,
  94    [
  95        /// Resets the edit prediction onboarding state.
  96        ResetOnboarding,
  97        /// Clears the edit prediction history.
  98        ClearHistory,
  99    ]
 100);
 101
 102/// Maximum number of events to track.
 103const EVENT_COUNT_MAX: usize = 10;
 104const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 105const EDIT_HISTORY_DIFF_SIZE_LIMIT: usize = 2048 * 3; // ~2048 tokens or ~50% of typical prompt budget
 106const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
 107const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
 108const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 109const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 110const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
 111const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
 112const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
 113
 114pub struct EditPredictionJumpsFeatureFlag;
 115
 116impl FeatureFlag for EditPredictionJumpsFeatureFlag {
 117    const NAME: &'static str = "edit_prediction_jumps";
 118}
 119
 120#[derive(Clone)]
 121struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
 122
 123impl Global for EditPredictionStoreGlobal {}
 124
 125/// Configuration for using the raw Zeta2 endpoint.
 126/// When set, the client uses the raw endpoint and constructs the prompt itself.
 127/// The version is also used as the Baseten environment name (lowercased).
 128#[derive(Clone)]
 129pub struct Zeta2RawConfig {
 130    pub model_id: Option<String>,
 131    pub environment: Option<String>,
 132    pub format: ZetaFormat,
 133}
 134
 135pub struct EditPredictionStore {
 136    client: Arc<Client>,
 137    user_store: Entity<UserStore>,
 138    llm_token: LlmApiToken,
 139    _fetch_experiments_task: Task<()>,
 140    projects: HashMap<EntityId, ProjectState>,
 141    update_required: bool,
 142    edit_prediction_model: EditPredictionModel,
 143    zeta2_raw_config: Option<Zeta2RawConfig>,
 144    preferred_experiment: Option<String>,
 145    available_experiments: Vec<String>,
 146    pub mercury: Mercury,
 147    data_collection_choice: DataCollectionChoice,
 148    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
 149    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
 150    shown_predictions: VecDeque<EditPrediction>,
 151    rated_predictions: HashSet<EditPredictionId>,
 152    #[cfg(test)]
 153    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
 154    credentials_provider: Arc<dyn CredentialsProvider>,
 155}
 156
 157pub(crate) struct EditPredictionRejectionPayload {
 158    rejection: EditPredictionRejection,
 159    organization_id: Option<OrganizationId>,
 160}
 161
 162#[derive(Copy, Clone, PartialEq, Eq)]
 163pub enum EditPredictionModel {
 164    Zeta,
 165    Fim { format: EditPredictionPromptFormat },
 166    Mercury,
 167}
 168
 169#[derive(Clone)]
 170pub struct EditPredictionModelInput {
 171    project: Entity<Project>,
 172    buffer: Entity<Buffer>,
 173    snapshot: BufferSnapshot,
 174    position: Anchor,
 175    events: Vec<Arc<zeta_prompt::Event>>,
 176    related_files: Vec<RelatedFile>,
 177    trigger: PredictEditsRequestTrigger,
 178    diagnostic_search_range: Range<Point>,
 179    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 180    can_collect_data: bool,
 181    is_open_source: bool,
 182}
 183
 184#[derive(Debug)]
 185pub enum DebugEvent {
 186    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
 187    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
 188    EditPredictionStarted(EditPredictionStartedDebugEvent),
 189    EditPredictionFinished(EditPredictionFinishedDebugEvent),
 190}
 191
 192#[derive(Debug)]
 193pub struct ContextRetrievalStartedDebugEvent {
 194    pub project_entity_id: EntityId,
 195    pub timestamp: Instant,
 196    pub search_prompt: String,
 197}
 198
 199#[derive(Debug)]
 200pub struct ContextRetrievalFinishedDebugEvent {
 201    pub project_entity_id: EntityId,
 202    pub timestamp: Instant,
 203    pub metadata: Vec<(&'static str, SharedString)>,
 204}
 205
 206#[derive(Debug)]
 207pub struct EditPredictionStartedDebugEvent {
 208    pub buffer: WeakEntity<Buffer>,
 209    pub position: Anchor,
 210    pub prompt: Option<String>,
 211}
 212
 213#[derive(Debug)]
 214pub struct EditPredictionFinishedDebugEvent {
 215    pub buffer: WeakEntity<Buffer>,
 216    pub position: Anchor,
 217    pub model_output: Option<String>,
 218}
 219
 220/// An event with associated metadata for reconstructing buffer state.
 221#[derive(Clone)]
 222pub struct StoredEvent {
 223    pub event: Arc<zeta_prompt::Event>,
 224    pub old_snapshot: TextBufferSnapshot,
 225    pub new_snapshot_version: clock::Global,
 226    pub total_edit_range: Range<Anchor>,
 227}
 228
 229impl StoredEvent {
 230    fn can_merge(
 231        &self,
 232        next_old_event: &StoredEvent,
 233        latest_snapshot: &TextBufferSnapshot,
 234        latest_edit_range: &Range<Anchor>,
 235    ) -> bool {
 236        // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
 237        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 238            return false;
 239        }
 240        if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
 241            return false;
 242        }
 243        if self.new_snapshot_version != next_old_event.old_snapshot.version {
 244            return false;
 245        }
 246        if !latest_snapshot
 247            .version
 248            .observed_all(&next_old_event.new_snapshot_version)
 249        {
 250            return false;
 251        }
 252
 253        let a_is_predicted = matches!(
 254            self.event.as_ref(),
 255            zeta_prompt::Event::BufferChange {
 256                predicted: true,
 257                ..
 258            }
 259        );
 260        let b_is_predicted = matches!(
 261            next_old_event.event.as_ref(),
 262            zeta_prompt::Event::BufferChange {
 263                predicted: true,
 264                ..
 265            }
 266        );
 267
 268        // If events come from the same source (both predicted or both manual) then
 269        // we would have coalesced them already.
 270        if a_is_predicted == b_is_predicted {
 271            return false;
 272        }
 273
 274        let left_range = self.total_edit_range.to_point(latest_snapshot);
 275        let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
 276        let latest_range = latest_edit_range.to_point(latest_snapshot);
 277
 278        // Events near to the latest edit are not merged if their sources differ.
 279        if lines_between_ranges(&left_range, &latest_range)
 280            .min(lines_between_ranges(&right_range, &latest_range))
 281            <= CHANGE_GROUPING_LINE_SPAN
 282        {
 283            return false;
 284        }
 285
 286        // Events that are distant from each other are not merged.
 287        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 288            return false;
 289        }
 290
 291        true
 292    }
 293}
 294
 295fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 296    if left.start > right.end {
 297        return left.start.row - right.end.row;
 298    }
 299    if right.start > left.end {
 300        return right.start.row - left.end.row;
 301    }
 302    0
 303}
 304
 305struct ProjectState {
 306    events: VecDeque<StoredEvent>,
 307    last_event: Option<LastEvent>,
 308    recent_paths: VecDeque<ProjectPath>,
 309    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 310    current_prediction: Option<CurrentEditPrediction>,
 311    next_pending_prediction_id: usize,
 312    pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
 313    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 314    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
 315    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
 316    cancelled_predictions: HashSet<usize>,
 317    context: Entity<RelatedExcerptStore>,
 318    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 319    _subscriptions: [gpui::Subscription; 2],
 320    copilot: Option<Entity<Copilot>>,
 321}
 322
 323impl ProjectState {
 324    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 325        self.events
 326            .iter()
 327            .cloned()
 328            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 329                let (one, two) = event.split_by_pause();
 330                let one = one.finalize(&self.license_detection_watchers, cx);
 331                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 332                one.into_iter().chain(two)
 333            }))
 334            .collect()
 335    }
 336
 337    fn cancel_pending_prediction(
 338        &mut self,
 339        pending_prediction: PendingPrediction,
 340        cx: &mut Context<EditPredictionStore>,
 341    ) {
 342        self.cancelled_predictions.insert(pending_prediction.id);
 343
 344        if pending_prediction.drop_on_cancel {
 345            drop(pending_prediction.task);
 346        } else {
 347            cx.spawn(async move |this, cx| {
 348                let Some(prediction_id) = pending_prediction.task.await else {
 349                    return;
 350                };
 351
 352                this.update(cx, |this, cx| {
 353                    this.reject_prediction(
 354                        prediction_id,
 355                        EditPredictionRejectReason::Canceled,
 356                        false,
 357                        None,
 358                        None,
 359                        cx,
 360                    );
 361                })
 362                .ok();
 363            })
 364            .detach()
 365        }
 366    }
 367
 368    fn active_buffer(
 369        &self,
 370        project: &Entity<Project>,
 371        cx: &App,
 372    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 373        let project = project.read(cx);
 374        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 375        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 376        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 377        Some((active_buffer, registered_buffer.last_position))
 378    }
 379}
 380
 381#[derive(Debug, Clone)]
 382struct CurrentEditPrediction {
 383    pub requested_by: PredictionRequestedBy,
 384    pub prediction: EditPrediction,
 385    pub was_shown: bool,
 386    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
 387    pub e2e_latency: std::time::Duration,
 388}
 389
 390impl CurrentEditPrediction {
 391    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 392        let Some(new_edits) = self
 393            .prediction
 394            .interpolate(&self.prediction.buffer.read(cx))
 395        else {
 396            return false;
 397        };
 398
 399        if self.prediction.buffer != old_prediction.prediction.buffer {
 400            return true;
 401        }
 402
 403        let Some(old_edits) = old_prediction
 404            .prediction
 405            .interpolate(&old_prediction.prediction.buffer.read(cx))
 406        else {
 407            return true;
 408        };
 409
 410        let requested_by_buffer_id = self.requested_by.buffer_id();
 411
 412        // This reduces the occurrence of UI thrash from replacing edits
 413        //
 414        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 415        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 416            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 417            && old_edits.len() == 1
 418            && new_edits.len() == 1
 419        {
 420            let (old_range, old_text) = &old_edits[0];
 421            let (new_range, new_text) = &new_edits[0];
 422            new_range == old_range && new_text.starts_with(old_text.as_ref())
 423        } else {
 424            true
 425        }
 426    }
 427}
 428
 429#[derive(Debug, Clone)]
 430enum PredictionRequestedBy {
 431    DiagnosticsUpdate,
 432    Buffer(EntityId),
 433}
 434
 435impl PredictionRequestedBy {
 436    pub fn buffer_id(&self) -> Option<EntityId> {
 437        match self {
 438            PredictionRequestedBy::DiagnosticsUpdate => None,
 439            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 440        }
 441    }
 442}
 443
 444const DIAGNOSTIC_LINES_RANGE: u32 = 20;
 445
 446#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 447pub enum DiagnosticSearchScope {
 448    Local,
 449    Global,
 450}
 451
 452#[derive(Debug)]
 453struct PendingPrediction {
 454    id: usize,
 455    task: Task<Option<EditPredictionId>>,
 456    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
 457    /// If false, the task is awaited to completion so rejection can be reported.
 458    drop_on_cancel: bool,
 459}
 460
 461/// A prediction from the perspective of a buffer.
 462#[derive(Debug)]
 463enum BufferEditPrediction<'a> {
 464    Local { prediction: &'a EditPrediction },
 465    Jump { prediction: &'a EditPrediction },
 466}
 467
 468#[cfg(test)]
 469impl std::ops::Deref for BufferEditPrediction<'_> {
 470    type Target = EditPrediction;
 471
 472    fn deref(&self) -> &Self::Target {
 473        match self {
 474            BufferEditPrediction::Local { prediction } => prediction,
 475            BufferEditPrediction::Jump { prediction } => prediction,
 476        }
 477    }
 478}
 479
 480#[derive(Clone)]
 481
 482struct PendingSettledPrediction {
 483    request_id: EditPredictionId,
 484    editable_anchor_range: Range<Anchor>,
 485    example: Option<ExampleSpec>,
 486    enqueued_at: Instant,
 487    last_edit_at: Instant,
 488    e2e_latency: std::time::Duration,
 489}
 490
 491struct RegisteredBuffer {
 492    file: Option<Arc<dyn File>>,
 493    snapshot: TextBufferSnapshot,
 494    pending_predictions: Vec<PendingSettledPrediction>,
 495    last_position: Option<Anchor>,
 496    _subscriptions: [gpui::Subscription; 2],
 497}
 498
 499#[derive(Clone)]
 500struct LastEvent {
 501    old_snapshot: TextBufferSnapshot,
 502    new_snapshot: TextBufferSnapshot,
 503    old_file: Option<Arc<dyn File>>,
 504    new_file: Option<Arc<dyn File>>,
 505    latest_edit_range: Range<Anchor>,
 506    total_edit_range: Range<Anchor>,
 507    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
 508    predicted: bool,
 509    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
 510    last_edit_time: Option<Instant>,
 511}
 512
 513impl LastEvent {
 514    pub fn finalize(
 515        &self,
 516        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 517        cx: &App,
 518    ) -> Option<StoredEvent> {
 519        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 520        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 521
 522        let in_open_source_repo =
 523            [self.new_file.as_ref(), self.old_file.as_ref()]
 524                .iter()
 525                .all(|file| {
 526                    file.is_some_and(|file| {
 527                        license_detection_watchers
 528                            .get(&file.worktree_id(cx))
 529                            .is_some_and(|watcher| watcher.is_project_open_source())
 530                    })
 531                });
 532
 533        let (diff, edit_range) = compute_diff_between_snapshots_in_range(
 534            &self.old_snapshot,
 535            &self.new_snapshot,
 536            &self.total_edit_range,
 537        )?;
 538
 539        if path == old_path && diff.is_empty() {
 540            None
 541        } else {
 542            Some(StoredEvent {
 543                event: Arc::new(zeta_prompt::Event::BufferChange {
 544                    old_path,
 545                    path,
 546                    diff,
 547                    in_open_source_repo,
 548                    predicted: self.predicted,
 549                }),
 550                old_snapshot: self.old_snapshot.clone(),
 551                new_snapshot_version: self.new_snapshot.version.clone(),
 552                total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
 553                    ..self.new_snapshot.anchor_before(edit_range.end),
 554            })
 555        }
 556    }
 557
 558    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 559        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 560            return (self.clone(), None);
 561        };
 562
 563        let total_edit_range_before_pause = self
 564            .total_edit_range_at_last_pause_boundary
 565            .clone()
 566            .unwrap_or_else(|| self.total_edit_range.clone());
 567
 568        let Some(total_edit_range_after_pause) =
 569            compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
 570        else {
 571            return (self.clone(), None);
 572        };
 573
 574        let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
 575        let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
 576
 577        let before = LastEvent {
 578            old_snapshot: self.old_snapshot.clone(),
 579            new_snapshot: boundary_snapshot.clone(),
 580            old_file: self.old_file.clone(),
 581            new_file: self.new_file.clone(),
 582            latest_edit_range: latest_edit_range_before_pause,
 583            total_edit_range: total_edit_range_before_pause,
 584            total_edit_range_at_last_pause_boundary: None,
 585            predicted: self.predicted,
 586            snapshot_after_last_editing_pause: None,
 587            last_edit_time: self.last_edit_time,
 588        };
 589
 590        let after = LastEvent {
 591            old_snapshot: boundary_snapshot.clone(),
 592            new_snapshot: self.new_snapshot.clone(),
 593            old_file: self.old_file.clone(),
 594            new_file: self.new_file.clone(),
 595            latest_edit_range: latest_edit_range_after_pause,
 596            total_edit_range: total_edit_range_after_pause,
 597            total_edit_range_at_last_pause_boundary: None,
 598            predicted: self.predicted,
 599            snapshot_after_last_editing_pause: None,
 600            last_edit_time: self.last_edit_time,
 601        };
 602
 603        (before, Some(after))
 604    }
 605}
 606
 607fn compute_total_edit_range_between_snapshots(
 608    old_snapshot: &TextBufferSnapshot,
 609    new_snapshot: &TextBufferSnapshot,
 610) -> Option<Range<Anchor>> {
 611    let edits: Vec<Edit<usize>> = new_snapshot
 612        .edits_since::<usize>(&old_snapshot.version)
 613        .collect();
 614
 615    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 616    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 617    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 618
 619    Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
 620}
 621
 622fn compute_old_range_for_new_range(
 623    old_snapshot: &TextBufferSnapshot,
 624    new_snapshot: &TextBufferSnapshot,
 625    total_edit_range: &Range<Anchor>,
 626) -> Option<Range<Point>> {
 627    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
 628    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);
 629
 630    let edits: Vec<Edit<usize>> = new_snapshot
 631        .edits_since::<usize>(&old_snapshot.version)
 632        .collect();
 633    let mut old_start_offset = None;
 634    let mut old_end_offset = None;
 635    let mut delta: isize = 0;
 636
 637    for edit in &edits {
 638        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
 639            old_start_offset = Some(if new_start_offset < edit.new.start {
 640                new_start_offset.checked_add_signed(-delta)?
 641            } else {
 642                edit.old.start
 643            });
 644        }
 645
 646        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
 647            old_end_offset = Some(if new_end_offset < edit.new.start {
 648                new_end_offset.checked_add_signed(-delta)?
 649            } else {
 650                edit.old.end
 651            });
 652        }
 653
 654        delta += edit.new.len() as isize - edit.old.len() as isize;
 655    }
 656
 657    let old_start_offset =
 658        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
 659    let old_end_offset =
 660        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));
 661
 662    Some(
 663        old_snapshot.offset_to_point(old_start_offset)
 664            ..old_snapshot.offset_to_point(old_end_offset),
 665    )
 666}
 667
 668fn compute_diff_between_snapshots_in_range(
 669    old_snapshot: &TextBufferSnapshot,
 670    new_snapshot: &TextBufferSnapshot,
 671    total_edit_range: &Range<Anchor>,
 672) -> Option<(String, Range<Point>)> {
 673    let new_start_point = total_edit_range.start.to_point(new_snapshot);
 674    let new_end_point = total_edit_range.end.to_point(new_snapshot);
 675    let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
 676    let old_start_point = old_range.start;
 677    let old_end_point = old_range.end;
 678
 679    const CONTEXT_LINES: u32 = 3;
 680
 681    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 682    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 683    let old_context_end_row =
 684        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 685    let new_context_end_row =
 686        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 687
 688    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 689    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 690    let old_end_line_offset = old_snapshot
 691        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 692    let new_end_line_offset = new_snapshot
 693        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 694    let old_edit_range = old_start_line_offset..old_end_line_offset;
 695    let new_edit_range = new_start_line_offset..new_end_line_offset;
 696
 697    if new_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
 698        || old_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
 699    {
 700        return None;
 701    }
 702
 703    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 704    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 705
 706    let diff = language::unified_diff_with_offsets(
 707        &old_region_text,
 708        &new_region_text,
 709        old_context_start_row,
 710        new_context_start_row,
 711    );
 712
 713    Some((diff, new_start_point..new_end_point))
 714}
 715
 716fn buffer_path_with_id_fallback(
 717    file: Option<&Arc<dyn File>>,
 718    snapshot: &TextBufferSnapshot,
 719    cx: &App,
 720) -> Arc<Path> {
 721    if let Some(file) = file {
 722        file.full_path(cx).into()
 723    } else {
 724        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 725    }
 726}
 727
 728impl EditPredictionStore {
 729    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 730        cx.try_global::<EditPredictionStoreGlobal>()
 731            .map(|global| global.0.clone())
 732    }
 733
 734    pub fn global(
 735        client: &Arc<Client>,
 736        user_store: &Entity<UserStore>,
 737        cx: &mut App,
 738    ) -> Entity<Self> {
 739        cx.try_global::<EditPredictionStoreGlobal>()
 740            .map(|global| global.0.clone())
 741            .unwrap_or_else(|| {
 742                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 743                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 744                ep_store
 745            })
 746    }
 747
 748    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 749        let data_collection_choice = Self::load_data_collection_choice(cx);
 750
 751        let llm_token = global_llm_token(cx);
 752
 753        let (reject_tx, reject_rx) = mpsc::unbounded();
 754        cx.background_spawn({
 755            let client = client.clone();
 756            let llm_token = llm_token.clone();
 757            let app_version = AppVersion::global(cx);
 758            let background_executor = cx.background_executor().clone();
 759            async move {
 760                Self::handle_rejected_predictions(
 761                    reject_rx,
 762                    client,
 763                    llm_token,
 764                    app_version,
 765                    background_executor,
 766                )
 767                .await
 768            }
 769        })
 770        .detach();
 771
 772        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
 773        cx.spawn(async move |this, cx| {
 774            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
 775        })
 776        .detach();
 777
 778        let mut current_user = user_store.read(cx).watch_current_user();
 779        let fetch_experiments_task = cx.spawn(async move |this, cx| {
 780            while current_user.borrow().is_none() {
 781                current_user.next().await;
 782            }
 783
 784            this.update(cx, |this, cx| {
 785                if cx.is_staff() {
 786                    this.refresh_available_experiments(cx);
 787                }
 788            })
 789            .log_err();
 790        });
 791
 792        let credentials_provider = zed_credentials_provider::global(cx);
 793
 794        let this = Self {
 795            projects: HashMap::default(),
 796            client,
 797            user_store,
 798            llm_token,
 799            _fetch_experiments_task: fetch_experiments_task,
 800            update_required: false,
 801            edit_prediction_model: EditPredictionModel::Zeta,
 802            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 803            preferred_experiment: None,
 804            available_experiments: Vec::new(),
 805            mercury: Mercury::new(cx),
 806
 807            data_collection_choice,
 808            reject_predictions_tx: reject_tx,
 809            settled_predictions_tx,
 810            rated_predictions: Default::default(),
 811            shown_predictions: Default::default(),
 812            #[cfg(test)]
 813            settled_event_callback: None,
 814
 815            credentials_provider,
 816        };
 817
 818        this
 819    }
 820
 821    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 822        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 823        let format = ZetaFormat::parse(&version_str).ok()?;
 824        let model_id = env::var("ZED_ZETA_MODEL").ok();
 825        let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
 826        Some(Zeta2RawConfig {
 827            model_id,
 828            environment,
 829            format,
 830        })
 831    }
 832
 833    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 834        self.edit_prediction_model = model;
 835    }
 836
    /// Stores `config` as the raw Zeta2 configuration, replacing any previous value.
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }
 840
    /// Returns the raw Zeta2 configuration, or `None` when none is set.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }
 844
    /// Returns the name of the preferred experiment, or `None` when none is set.
    pub fn preferred_experiment(&self) -> Option<&str> {
        self.preferred_experiment.as_deref()
    }
 848
    /// Sets the preferred experiment; passing `None` clears it.
    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
        self.preferred_experiment = experiment;
    }
 852
 853    pub fn available_experiments(&self) -> &[String] {
 854        &self.available_experiments
 855    }
 856
 857    pub fn active_experiment(&self) -> Option<&str> {
 858        self.preferred_experiment.as_deref().or_else(|| {
 859            self.shown_predictions
 860                .iter()
 861                .find_map(|p| p.model_version.as_ref())
 862                .and_then(|model_version| model_version.strip_prefix("zeta2:"))
 863        })
 864    }
 865
 866    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
 867        let client = self.client.clone();
 868        let llm_token = self.llm_token.clone();
 869        let app_version = AppVersion::global(cx);
 870        let organization_id = self
 871            .user_store
 872            .read(cx)
 873            .current_organization()
 874            .map(|organization| organization.id.clone());
 875
 876        cx.spawn(async move |this, cx| {
 877            let experiments = cx
 878                .background_spawn(async move {
 879                    let http_client = client.http_client();
 880                    let token = client
 881                        .acquire_llm_token(&llm_token, organization_id.clone())
 882                        .await?;
 883                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
 884                    let request = http_client::Request::builder()
 885                        .method(Method::GET)
 886                        .uri(url.as_ref())
 887                        .header("Authorization", format!("Bearer {}", token))
 888                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 889                        .body(Default::default())?;
 890                    let mut response = http_client.send(request).await?;
 891                    if response.status().is_success() {
 892                        let mut body = Vec::new();
 893                        response.body_mut().read_to_end(&mut body).await?;
 894                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
 895                        Ok(experiments)
 896                    } else {
 897                        let mut body = String::new();
 898                        response.body_mut().read_to_string(&mut body).await?;
 899                        anyhow::bail!(
 900                            "Failed to fetch experiments: {:?}\nBody: {}",
 901                            response.status(),
 902                            body
 903                        );
 904                    }
 905                })
 906                .await?;
 907            this.update(cx, |this, cx| {
 908                this.available_experiments = experiments;
 909                cx.notify();
 910            })?;
 911            anyhow::Ok(())
 912        })
 913        .detach_and_log_err(cx);
 914    }
 915
 916    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 917        use ui::IconName;
 918        match self.edit_prediction_model {
 919            EditPredictionModel::Mercury => {
 920                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 921            }
 922            EditPredictionModel::Zeta => {
 923                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 924                    .with_disabled(IconName::ZedPredictDisabled)
 925                    .with_up(IconName::ZedPredictUp)
 926                    .with_down(IconName::ZedPredictDown)
 927                    .with_error(IconName::ZedPredictError)
 928            }
 929            EditPredictionModel::Fim { .. } => {
 930                let settings = &all_language_settings(None, cx).edit_predictions;
 931                match settings.provider {
 932                    EditPredictionProvider::Ollama => {
 933                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 934                    }
 935                    _ => {
 936                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 937                    }
 938                }
 939            }
 940        }
 941    }
 942
 943    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 944        self.mercury.api_token.read(cx).has_key()
 945    }
 946
    /// Forwards to `Mercury::has_payment_required_error` on the Mercury state held by
    /// this store.
    pub fn mercury_has_payment_required_error(&self) -> bool {
        self.mercury.has_payment_required_error()
    }
 950
 951    pub fn clear_history(&mut self) {
 952        for project_state in self.projects.values_mut() {
 953            project_state.events.clear();
 954            project_state.last_event.take();
 955        }
 956    }
 957
 958    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 959        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 960            project_state.events.clear();
 961            project_state.last_event.take();
 962        }
 963    }
 964
 965    pub fn edit_history_for_project(
 966        &self,
 967        project: &Entity<Project>,
 968        cx: &App,
 969    ) -> Vec<StoredEvent> {
 970        self.projects
 971            .get(&project.entity_id())
 972            .map(|project_state| project_state.events(cx))
 973            .unwrap_or_default()
 974    }
 975
 976    pub fn context_for_project<'a>(
 977        &'a self,
 978        project: &Entity<Project>,
 979        cx: &'a mut App,
 980    ) -> Vec<RelatedFile> {
 981        self.projects
 982            .get(&project.entity_id())
 983            .map(|project_state| {
 984                project_state.context.update(cx, |context, cx| {
 985                    context
 986                        .related_files_with_buffers(cx)
 987                        .map(|(mut related_file, buffer)| {
 988                            related_file.in_open_source_repo = buffer
 989                                .read(cx)
 990                                .file()
 991                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
 992                            related_file
 993                        })
 994                        .collect()
 995                })
 996            })
 997            .unwrap_or_default()
 998    }
 999
1000    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1001        self.projects
1002            .get(&project.entity_id())
1003            .and_then(|project| project.copilot.clone())
1004    }
1005
1006    pub fn start_copilot_for_project(
1007        &mut self,
1008        project: &Entity<Project>,
1009        cx: &mut Context<Self>,
1010    ) -> Option<Entity<Copilot>> {
1011        if DisableAiSettings::get(None, cx).disable_ai {
1012            return None;
1013        }
1014        let state = self.get_or_init_project(project, cx);
1015
1016        if state.copilot.is_some() {
1017            return state.copilot.clone();
1018        }
1019        let _project = project.clone();
1020        let project = project.read(cx);
1021
1022        let node = project.node_runtime().cloned();
1023        if let Some(node) = node {
1024            let next_id = project.languages().next_language_server_id();
1025            let fs = project.fs().clone();
1026
1027            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1028            state.copilot = Some(copilot.clone());
1029            Some(copilot)
1030        } else {
1031            None
1032        }
1033    }
1034
1035    pub fn context_for_project_with_buffers<'a>(
1036        &'a self,
1037        project: &Entity<Project>,
1038        cx: &'a mut App,
1039    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1040        self.projects
1041            .get(&project.entity_id())
1042            .map(|project| {
1043                project.context.update(cx, |context, cx| {
1044                    context.related_files_with_buffers(cx).collect()
1045                })
1046            })
1047            .unwrap_or_default()
1048    }
1049
1050    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1051        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1052            self.user_store.read(cx).edit_prediction_usage()
1053        } else {
1054            None
1055        }
1056    }
1057
    /// Ensures a `ProjectState` exists for `project`, creating it on first use.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
1061
1062    pub fn register_buffer(
1063        &mut self,
1064        buffer: &Entity<Buffer>,
1065        project: &Entity<Project>,
1066        cx: &mut Context<Self>,
1067    ) {
1068        let project_state = self.get_or_init_project(project, cx);
1069        Self::register_buffer_impl(project_state, buffer, project, cx);
1070    }
1071
1072    fn get_or_init_project(
1073        &mut self,
1074        project: &Entity<Project>,
1075        cx: &mut Context<Self>,
1076    ) -> &mut ProjectState {
1077        let entity_id = project.entity_id();
1078        self.projects
1079            .entry(entity_id)
1080            .or_insert_with(|| ProjectState {
1081                context: {
1082                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
1083                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
1084                        this.handle_excerpt_store_event(entity_id, event);
1085                    })
1086                    .detach();
1087                    related_excerpt_store
1088                },
1089                events: VecDeque::new(),
1090                last_event: None,
1091                recent_paths: VecDeque::new(),
1092                debug_tx: None,
1093                registered_buffers: HashMap::default(),
1094                current_prediction: None,
1095                cancelled_predictions: HashSet::default(),
1096                pending_predictions: ArrayVec::new(),
1097                next_pending_prediction_id: 0,
1098                last_edit_prediction_refresh: None,
1099                last_jump_prediction_refresh: None,
1100                license_detection_watchers: HashMap::default(),
1101                _subscriptions: [
1102                    cx.subscribe(&project, Self::handle_project_event),
1103                    cx.observe_release(&project, move |this, _, cx| {
1104                        this.projects.remove(&entity_id);
1105                        cx.notify();
1106                    }),
1107                ],
1108                copilot: None,
1109            })
1110    }
1111
1112    pub fn remove_project(&mut self, project: &Entity<Project>) {
1113        self.projects.remove(&project.entity_id());
1114    }
1115
1116    fn handle_excerpt_store_event(
1117        &mut self,
1118        project_entity_id: EntityId,
1119        event: &RelatedExcerptStoreEvent,
1120    ) {
1121        if let Some(project_state) = self.projects.get(&project_entity_id) {
1122            if let Some(debug_tx) = project_state.debug_tx.clone() {
1123                match event {
1124                    RelatedExcerptStoreEvent::StartedRefresh => {
1125                        debug_tx
1126                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
1127                                ContextRetrievalStartedDebugEvent {
1128                                    project_entity_id: project_entity_id,
1129                                    timestamp: Instant::now(),
1130                                    search_prompt: String::new(),
1131                                },
1132                            ))
1133                            .ok();
1134                    }
1135                    RelatedExcerptStoreEvent::FinishedRefresh {
1136                        cache_hit_count,
1137                        cache_miss_count,
1138                        mean_definition_latency,
1139                        max_definition_latency,
1140                    } => {
1141                        debug_tx
1142                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
1143                                ContextRetrievalFinishedDebugEvent {
1144                                    project_entity_id: project_entity_id,
1145                                    timestamp: Instant::now(),
1146                                    metadata: vec![
1147                                        (
1148                                            "Cache Hits",
1149                                            format!(
1150                                                "{}/{}",
1151                                                cache_hit_count,
1152                                                cache_hit_count + cache_miss_count
1153                                            )
1154                                            .into(),
1155                                        ),
1156                                        (
1157                                            "Max LSP Time",
1158                                            format!("{} ms", max_definition_latency.as_millis())
1159                                                .into(),
1160                                        ),
1161                                        (
1162                                            "Mean LSP Time",
1163                                            format!("{} ms", mean_definition_latency.as_millis())
1164                                                .into(),
1165                                        ),
1166                                    ],
1167                                },
1168                            ))
1169                            .ok();
1170                    }
1171                }
1172            }
1173        }
1174    }
1175
1176    pub fn debug_info(
1177        &mut self,
1178        project: &Entity<Project>,
1179        cx: &mut Context<Self>,
1180    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1181        let project_state = self.get_or_init_project(project, cx);
1182        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1183        project_state.debug_tx = Some(debug_watch_tx);
1184        debug_watch_rx
1185    }
1186
1187    fn handle_project_event(
1188        &mut self,
1189        project: Entity<Project>,
1190        event: &project::Event,
1191        cx: &mut Context<Self>,
1192    ) {
1193        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1194            return;
1195        }
1196        // TODO [zeta2] init with recent paths
1197        match event {
1198            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1199                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1200                    return;
1201                };
1202                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1203                if let Some(path) = path {
1204                    if let Some(ix) = project_state
1205                        .recent_paths
1206                        .iter()
1207                        .position(|probe| probe == &path)
1208                    {
1209                        project_state.recent_paths.remove(ix);
1210                    }
1211                    project_state.recent_paths.push_front(path);
1212                }
1213            }
1214            project::Event::DiagnosticsUpdated { .. } => {
1215                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1216                    self.refresh_prediction_from_diagnostics(
1217                        project,
1218                        DiagnosticSearchScope::Global,
1219                        cx,
1220                    );
1221                }
1222            }
1223            _ => (),
1224        }
1225    }
1226
    /// Returns the `RegisteredBuffer` entry for `buffer` inside `project_state`, creating it
    /// when the buffer has not been registered yet.
    ///
    /// If the buffer has a file whose worktree is found in `project`, this also makes sure a
    /// `LicenseDetectionWatcher` exists for that worktree.
    ///
    /// A newly created entry captures the buffer's current text snapshot and file, forwards
    /// the buffer's `Edited` events to `report_changes_for_buffer`, and is removed again when
    /// the buffer entity is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                // One watcher per worktree: only the first buffer seen in a worktree creates it.
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher entry once the worktree itself is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Hold the project weakly; edits are ignored once it can no
                            // longer be upgraded.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    // Edits observed through the subscription are reported
                                    // with `is_predicted = false`.
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        // Unregister the buffer when its entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1295
    /// Records a change to `buffer` in the edit history of `project`.
    ///
    /// The project and buffer are registered if needed. The buffer's current text snapshot
    /// is compared with the snapshot stored for it; the resulting edit either extends the
    /// in-progress `last_event` or causes that event to be finalized and a new one started.
    ///
    /// * `is_predicted` - stored on a new event as `predicted`; an edit whose flag differs
    ///   from the in-progress event's is never coalesced into it.
    /// * `is_local` - when `false`, the edit is only recorded if
    ///   `collaborator_edit_overlaps_locality_region` accepts it.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Same version as the stored snapshot: there is nothing new to record.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        // From here on the registered buffer holds the new file and snapshot, even if one
        // of the early returns below is taken.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut edit_range: Option<Range<Anchor>> = None;
        let now = cx.background_executor().now();

        // Fold every edit since the old snapshot into one range: the start of the first
        // edit yielded through the end of the last edit yielded.
        for (_edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // An edit overlapping a pending prediction's editable range refreshes that
        // prediction's last-edit timestamp.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if !include_in_history {
            return;
        }

        let is_recordable_history_edit =
            compute_diff_between_snapshots_in_range(&old_snapshot, &new_snapshot, &edit_range)
                .is_some();

        let events = &mut project_state.events;

        // No diff could be computed for this range: flush the in-progress event (if any)
        // into `events` and do not start a new one.
        if !is_recordable_history_edit {
            if let Some(event) = project_state.last_event.take() {
                if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                    // Make room by dropping the oldest event when the queue is about to
                    // reach `EVENT_COUNT_MAX`.
                    if events.len() + 1 >= EVENT_COUNT_MAX {
                        events.pop_front();
                    }
                    events.push_back(event);
                }
            }
            return;
        }

        if let Some(last_event) = project_state.last_event.as_mut() {
            // The in-progress event must end exactly where this edit starts: same buffer
            // and same version.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            // Coalesce only when this edit is within `CHANGE_GROUPING_LINE_SPAN` lines of
            // the event's latest edit.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                // After a pause of at least `LAST_CHANGE_GROUPING_TIME`, remember the
                // event's state from before this edit as the pause boundary.
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescing: finalize the in-progress event (if any) before starting a new one.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1428
1429    fn prediction_at(
1430        &mut self,
1431        buffer: &Entity<Buffer>,
1432        position: Option<language::Anchor>,
1433        project: &Entity<Project>,
1434        cx: &App,
1435    ) -> Option<BufferEditPrediction<'_>> {
1436        let project_state = self.projects.get_mut(&project.entity_id())?;
1437        if let Some(position) = position
1438            && let Some(buffer) = project_state
1439                .registered_buffers
1440                .get_mut(&buffer.entity_id())
1441        {
1442            buffer.last_position = Some(position);
1443        }
1444
1445        let CurrentEditPrediction {
1446            requested_by,
1447            prediction,
1448            ..
1449        } = project_state.current_prediction.as_ref()?;
1450
1451        if prediction.targets_buffer(buffer.read(cx)) {
1452            Some(BufferEditPrediction::Local { prediction })
1453        } else {
1454            let show_jump = match requested_by {
1455                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1456                    requested_by_buffer_id == &buffer.entity_id()
1457                }
1458                PredictionRequestedBy::DiagnosticsUpdate => true,
1459            };
1460
1461            if show_jump {
1462                Some(BufferEditPrediction::Jump { prediction })
1463            } else {
1464                None
1465            }
1466        }
1467    }
1468
1469    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1470        let Some(current_prediction) = self
1471            .projects
1472            .get_mut(&project.entity_id())
1473            .and_then(|project_state| project_state.current_prediction.take())
1474        else {
1475            return;
1476        };
1477
1478        self.report_changes_for_buffer(
1479            &current_prediction.prediction.buffer,
1480            project,
1481            true,
1482            true,
1483            cx,
1484        );
1485
1486        // can't hold &mut project_state ref across report_changes_for_buffer_call
1487        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1488            return;
1489        };
1490
1491        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1492            project_state.cancel_pending_prediction(pending_prediction, cx);
1493        }
1494
1495        match self.edit_prediction_model {
1496            EditPredictionModel::Mercury => {
1497                mercury::edit_prediction_accepted(
1498                    current_prediction.prediction.id,
1499                    self.client.http_client(),
1500                    cx,
1501                );
1502            }
1503            EditPredictionModel::Zeta => {
1504                let is_cloud = !matches!(
1505                    all_language_settings(None, cx).edit_predictions.provider,
1506                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1507                );
1508                if is_cloud {
1509                    zeta::edit_prediction_accepted(self, current_prediction, cx)
1510                }
1511            }
1512            EditPredictionModel::Fim { .. } => {}
1513        }
1514    }
1515
1516    async fn handle_rejected_predictions(
1517        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
1518        client: Arc<Client>,
1519        llm_token: LlmApiToken,
1520        app_version: Version,
1521        background_executor: BackgroundExecutor,
1522    ) {
1523        let mut rx = std::pin::pin!(rx.peekable());
1524        let mut batched = Vec::new();
1525
1526        while let Some(EditPredictionRejectionPayload {
1527            rejection,
1528            organization_id,
1529        }) = rx.next().await
1530        {
1531            batched.push(rejection);
1532
1533            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1534                select_biased! {
1535                    next = rx.as_mut().peek().fuse() => {
1536                        if next.is_some() {
1537                            continue;
1538                        }
1539                    }
1540                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1541                }
1542            }
1543
1544            let url = client
1545                .http_client()
1546                .build_zed_llm_url("/predict_edits/reject", &[])
1547                .unwrap();
1548
1549            let flush_count = batched
1550                .len()
1551                // in case items have accumulated after failure
1552                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1553            let start = batched.len() - flush_count;
1554
1555            let body = RejectEditPredictionsBodyRef {
1556                rejections: &batched[start..],
1557            };
1558
1559            let result = Self::send_api_request::<()>(
1560                |builder| {
1561                    let req = builder
1562                        .uri(url.as_ref())
1563                        .body(serde_json::to_string(&body)?.into());
1564                    anyhow::Ok(req?)
1565                },
1566                client.clone(),
1567                llm_token.clone(),
1568                organization_id,
1569                app_version.clone(),
1570                true,
1571            )
1572            .await;
1573
1574            if result.log_err().is_some() {
1575                batched.drain(start..);
1576            }
1577        }
1578    }
1579
1580    async fn run_settled_predictions_worker(
1581        this: WeakEntity<Self>,
1582        mut rx: UnboundedReceiver<Instant>,
1583        cx: &mut AsyncApp,
1584    ) {
1585        let mut next_wake_time: Option<Instant> = None;
1586        loop {
1587            let now = cx.background_executor().now();
1588            if let Some(wake_time) = next_wake_time.take() {
1589                cx.background_executor()
1590                    .timer(wake_time.duration_since(now))
1591                    .await;
1592            } else {
1593                let Some(new_enqueue_time) = rx.next().await else {
1594                    break;
1595                };
1596                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1597                while rx.next().now_or_never().flatten().is_some() {}
1598                continue;
1599            }
1600
1601            let Some(this) = this.upgrade() else {
1602                break;
1603            };
1604
1605            let now = cx.background_executor().now();
1606
1607            let mut oldest_edited_at = None;
1608
1609            this.update(cx, |this, _| {
1610                for (_, project_state) in this.projects.iter_mut() {
1611                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
1612                        registered_buffer
1613                            .pending_predictions
1614                            .retain_mut(|pending_prediction| {
1615                                let age =
1616                                    now.saturating_duration_since(pending_prediction.enqueued_at);
1617                                if age >= EDIT_PREDICTION_SETTLED_TTL {
1618                                    return false;
1619                                }
1620
1621                                let quiet_for =
1622                                    now.saturating_duration_since(pending_prediction.last_edit_at);
1623                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
1624                                    let settled_editable_region = registered_buffer
1625                                        .snapshot
1626                                        .text_for_range(
1627                                            pending_prediction.editable_anchor_range.clone(),
1628                                        )
1629                                        .collect::<String>();
1630
1631                                    #[cfg(test)]
1632                                    if let Some(callback) = &this.settled_event_callback {
1633                                        callback(
1634                                            pending_prediction.request_id.clone(),
1635                                            settled_editable_region.clone(),
1636                                        );
1637                                    }
1638
1639                                    telemetry::event!(
1640                                        EDIT_PREDICTION_SETTLED_EVENT,
1641                                        request_id = pending_prediction.request_id.0.clone(),
1642                                        settled_editable_region,
1643                                        example = pending_prediction.example.take(),
1644                                        e2e_latency = pending_prediction.e2e_latency.as_millis(),
1645                                    );
1646
1647                                    return false;
1648                                }
1649
1650                                if oldest_edited_at
1651                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
1652                                {
1653                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
1654                                }
1655
1656                                true
1657                            });
1658                    }
1659                }
1660            });
1661
1662            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1663        }
1664    }
1665
1666    pub(crate) fn enqueue_settled_prediction(
1667        &mut self,
1668        request_id: EditPredictionId,
1669        project: &Entity<Project>,
1670        edited_buffer: &Entity<Buffer>,
1671        edited_buffer_snapshot: &BufferSnapshot,
1672        editable_offset_range: Range<usize>,
1673        example: Option<ExampleSpec>,
1674        e2e_latency: std::time::Duration,
1675        cx: &mut Context<Self>,
1676    ) {
1677        let this = &mut *self;
1678        let project_state = this.get_or_init_project(project, cx);
1679        if let Some(buffer) = project_state
1680            .registered_buffers
1681            .get_mut(&edited_buffer.entity_id())
1682        {
1683            let now = cx.background_executor().now();
1684            buffer.pending_predictions.push(PendingSettledPrediction {
1685                request_id: request_id,
1686                editable_anchor_range: edited_buffer_snapshot
1687                    .anchor_range_inside(editable_offset_range),
1688                example,
1689                e2e_latency,
1690                enqueued_at: now,
1691                last_edit_at: now,
1692            });
1693            this.settled_predictions_tx.unbounded_send(now).ok();
1694        }
1695    }
1696
1697    fn reject_current_prediction(
1698        &mut self,
1699        reason: EditPredictionRejectReason,
1700        project: &Entity<Project>,
1701        cx: &App,
1702    ) {
1703        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1704            project_state.pending_predictions.clear();
1705            if let Some(prediction) = project_state.current_prediction.take() {
1706                let model_version = prediction.prediction.model_version.clone();
1707                self.reject_prediction(
1708                    prediction.prediction.id,
1709                    reason,
1710                    prediction.was_shown,
1711                    model_version,
1712                    Some(prediction.e2e_latency),
1713                    cx,
1714                );
1715            }
1716        };
1717    }
1718
1719    fn did_show_current_prediction(
1720        &mut self,
1721        project: &Entity<Project>,
1722        display_type: edit_prediction_types::SuggestionDisplayType,
1723        _cx: &mut Context<Self>,
1724    ) {
1725        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1726            return;
1727        };
1728
1729        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1730            return;
1731        };
1732
1733        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1734        let previous_shown_with = current_prediction.shown_with;
1735
1736        if previous_shown_with.is_none() || !is_jump {
1737            current_prediction.shown_with = Some(display_type);
1738        }
1739
1740        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1741
1742        if is_first_non_jump_show {
1743            current_prediction.was_shown = true;
1744        }
1745
1746        if is_first_non_jump_show {
1747            self.shown_predictions
1748                .push_front(current_prediction.prediction.clone());
1749            if self.shown_predictions.len() > 50 {
1750                let completion = self.shown_predictions.pop_back().unwrap();
1751                self.rated_predictions.remove(&completion.id);
1752            }
1753        }
1754    }
1755
1756    fn reject_prediction(
1757        &mut self,
1758        prediction_id: EditPredictionId,
1759        reason: EditPredictionRejectReason,
1760        was_shown: bool,
1761        model_version: Option<String>,
1762        e2e_latency: Option<std::time::Duration>,
1763        cx: &App,
1764    ) {
1765        match self.edit_prediction_model {
1766            EditPredictionModel::Zeta => {
1767                let is_cloud = !matches!(
1768                    all_language_settings(None, cx).edit_predictions.provider,
1769                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1770                );
1771
1772                if is_cloud {
1773                    let organization_id = self
1774                        .user_store
1775                        .read(cx)
1776                        .current_organization()
1777                        .map(|organization| organization.id.clone());
1778
1779                    self.reject_predictions_tx
1780                        .unbounded_send(EditPredictionRejectionPayload {
1781                            rejection: EditPredictionRejection {
1782                                request_id: prediction_id.to_string(),
1783                                reason,
1784                                was_shown,
1785                                model_version,
1786                                e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
1787                            },
1788                            organization_id,
1789                        })
1790                        .log_err();
1791                }
1792            }
1793            EditPredictionModel::Mercury => {
1794                mercury::edit_prediction_rejected(
1795                    prediction_id,
1796                    was_shown,
1797                    reason,
1798                    self.client.http_client(),
1799                    cx,
1800                );
1801            }
1802            EditPredictionModel::Fim { .. } => {}
1803        }
1804    }
1805
1806    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1807        self.projects
1808            .get(&project.entity_id())
1809            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1810    }
1811
1812    pub fn refresh_prediction_from_buffer(
1813        &mut self,
1814        project: Entity<Project>,
1815        buffer: Entity<Buffer>,
1816        position: language::Anchor,
1817        cx: &mut Context<Self>,
1818    ) {
1819        self.queue_prediction_refresh(
1820            project.clone(),
1821            PredictEditsRequestTrigger::Other,
1822            buffer.entity_id(),
1823            cx,
1824            move |this, cx| {
1825                let Some(request_task) = this
1826                    .update(cx, |this, cx| {
1827                        this.request_prediction(
1828                            &project,
1829                            &buffer,
1830                            position,
1831                            PredictEditsRequestTrigger::Other,
1832                            cx,
1833                        )
1834                    })
1835                    .log_err()
1836                else {
1837                    return Task::ready(anyhow::Ok(None));
1838                };
1839
1840                cx.spawn(async move |_cx| {
1841                    request_task.await.map(|prediction_result| {
1842                        prediction_result.map(|prediction_result| {
1843                            (
1844                                prediction_result,
1845                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1846                            )
1847                        })
1848                    })
1849                })
1850            },
1851        )
1852    }
1853
1854    pub fn refresh_prediction_from_diagnostics(
1855        &mut self,
1856        project: Entity<Project>,
1857        scope: DiagnosticSearchScope,
1858        cx: &mut Context<Self>,
1859    ) {
1860        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1861            return;
1862        }
1863
1864        if currently_following(&project, cx) {
1865            return;
1866        }
1867
1868        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1869            return;
1870        };
1871
1872        // Prefer predictions from buffer
1873        if project_state.current_prediction.is_some() {
1874            log::debug!(
1875                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
1876            );
1877            return;
1878        }
1879
1880        self.queue_prediction_refresh(
1881            project.clone(),
1882            PredictEditsRequestTrigger::Diagnostics,
1883            project.entity_id(),
1884            cx,
1885            move |this, cx| {
1886                let Some((active_buffer, snapshot, cursor_point)) = this
1887                    .read_with(cx, |this, cx| {
1888                        let project_state = this.projects.get(&project.entity_id())?;
1889                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
1890                        let snapshot = buffer.read(cx).snapshot();
1891
1892                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
1893                            return None;
1894                        }
1895
1896                        let cursor_point = position
1897                            .map(|pos| pos.to_point(&snapshot))
1898                            .unwrap_or_default();
1899
1900                        Some((buffer, snapshot, cursor_point))
1901                    })
1902                    .log_err()
1903                    .flatten()
1904                else {
1905                    return Task::ready(anyhow::Ok(None));
1906                };
1907
1908                cx.spawn(async move |cx| {
1909                    let diagnostic_search_range = match scope {
1910                        DiagnosticSearchScope::Local => {
1911                            let diagnostic_search_start =
1912                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1913                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1914                            Point::new(diagnostic_search_start, 0)
1915                                ..Point::new(diagnostic_search_end, 0)
1916                        }
1917                        DiagnosticSearchScope::Global => Default::default(),
1918                    };
1919
1920                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1921                        active_buffer,
1922                        &snapshot,
1923                        diagnostic_search_range,
1924                        cursor_point,
1925                        &project,
1926                        cx,
1927                    )
1928                    .await?
1929                    else {
1930                        return anyhow::Ok(None);
1931                    };
1932
1933                    let Some(prediction_result) = this
1934                        .update(cx, |this, cx| {
1935                            this.request_prediction(
1936                                &project,
1937                                &jump_buffer,
1938                                jump_position,
1939                                PredictEditsRequestTrigger::Diagnostics,
1940                                cx,
1941                            )
1942                        })?
1943                        .await?
1944                    else {
1945                        return anyhow::Ok(None);
1946                    };
1947
1948                    this.update(cx, |this, cx| {
1949                        Some((
1950                            if this
1951                                .get_or_init_project(&project, cx)
1952                                .current_prediction
1953                                .is_none()
1954                            {
1955                                prediction_result
1956                            } else {
1957                                EditPredictionResult {
1958                                    id: prediction_result.id,
1959                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1960                                    e2e_latency: prediction_result.e2e_latency,
1961                                }
1962                            },
1963                            PredictionRequestedBy::DiagnosticsUpdate,
1964                        ))
1965                    })
1966                })
1967            },
1968        );
1969    }
1970
1971    fn predictions_enabled_at(
1972        snapshot: &BufferSnapshot,
1973        position: Option<language::Anchor>,
1974        cx: &App,
1975    ) -> bool {
1976        let file = snapshot.file();
1977        let all_settings = all_language_settings(file, cx);
1978        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1979            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1980        {
1981            return false;
1982        }
1983
1984        if let Some(last_position) = position {
1985            let settings = snapshot.settings_at(last_position, cx);
1986
1987            if !settings.edit_predictions_disabled_in.is_empty()
1988                && let Some(scope) = snapshot.language_scope_at(last_position)
1989                && let Some(scope_name) = scope.override_name()
1990                && settings
1991                    .edit_predictions_disabled_in
1992                    .iter()
1993                    .any(|s| s == scope_name)
1994            {
1995                return false;
1996            }
1997        }
1998
1999        true
2000    }
2001
    /// Throttle window applied by `queue_prediction_refresh`: a refresh for the
    /// same throttle entity as the previously recorded refresh waits until this
    /// much time has passed since that refresh.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2003}
2004
2005fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2006    let Some(app_state) = AppState::try_global(cx) else {
2007        return false;
2008    };
2009
2010    app_state
2011        .workspace_store
2012        .read(cx)
2013        .workspaces()
2014        .filter_map(|workspace| workspace.upgrade())
2015        .any(|workspace| {
2016            workspace.read(cx).project().entity_id() == project.entity_id()
2017                && workspace
2018                    .read(cx)
2019                    .leader_for_pane(workspace.read(cx).active_pane())
2020                    .is_some()
2021        })
2022}
2023
2024fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2025    match provider {
2026        EditPredictionProvider::Zed
2027        | EditPredictionProvider::Mercury
2028        | EditPredictionProvider::Ollama
2029        | EditPredictionProvider::OpenAiCompatibleApi
2030        | EditPredictionProvider::Experimental(_) => true,
2031        EditPredictionProvider::None
2032        | EditPredictionProvider::Copilot
2033        | EditPredictionProvider::Codestral => false,
2034    }
2035}
2036
2037impl EditPredictionStore {
    /// Schedules a throttled prediction refresh for `project` and tracks it in
    /// the project's list of pending predictions.
    ///
    /// - `request_trigger` selects which throttle slot is used:
    ///   `Diagnostics` uses `last_jump_prediction_refresh`, everything else
    ///   uses `last_edit_prediction_refresh`.
    /// - `throttle_entity` is the throttle key; a refresh only waits when the
    ///   previously recorded refresh in that slot was for the same entity.
    /// - `do_refresh` performs the actual request and resolves to the
    ///   prediction result together with what requested it.
    ///
    /// The spawned task waits out any remaining throttle time, gives up if it
    /// was cancelled or if another refresh updated the throttle slot since it
    /// was enqueued, then runs `do_refresh` and installs or rejects the result.
    /// When it finishes it removes itself from the pending list and cancels
    /// the pending predictions that were queued before it.
    ///
    /// Logs an error and returns without scheduling anything if the configured
    /// provider is `None`, `Copilot` or `Codestral`.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        request_trigger: PredictEditsRequestTrigger,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        // Returns the throttle slot that belongs to the given trigger kind.
        fn select_throttle(
            project_state: &mut ProjectState,
            request_trigger: PredictEditsRequestTrigger,
        ) -> &mut Option<(EntityId, Instant)> {
            match request_trigger {
                PredictEditsRequestTrigger::Diagnostics => {
                    &mut project_state.last_jump_prediction_refresh
                }
                _ => &mut project_state.last_edit_prediction_refresh,
            }
        }

        // Per-provider policy: whether acceptance tracking is needed, and how
        // many predictions may be pending at once.
        let (needs_acceptance_tracking, max_pending_predictions) =
            match all_language_settings(None, cx).edit_predictions.provider {
                EditPredictionProvider::Zed
                | EditPredictionProvider::Mercury
                | EditPredictionProvider::Experimental(_) => (true, 2),
                EditPredictionProvider::Ollama => (false, 1),
                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
                EditPredictionProvider::None
                | EditPredictionProvider::Copilot
                | EditPredictionProvider::Codestral => {
                    log::error!("queue_prediction_refresh called with non-store provider");
                    return;
                }
            };

        let drop_on_cancel = !needs_acceptance_tracking;
        let throttle_timeout = Self::THROTTLE_TIMEOUT;
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        // Snapshot of the throttle slot at enqueue time; compared again after
        // the throttle wait to detect that another refresh went ahead first.
        let throttle_at_enqueue = *select_throttle(project_state, request_trigger);

        let task = cx.spawn(async move |this, cx| {
            // Remaining throttle time, if the last refresh in this slot was for
            // the same entity and is still within the throttle window.
            let throttle_wait = this
                .update(cx, |this, cx| {
                    let project_state = this.get_or_init_project(&project, cx);
                    let throttle = *select_throttle(project_state, request_trigger);

                    let now = cx.background_executor().now();
                    throttle.and_then(|(last_entity, last_timestamp)| {
                        if throttle_entity != last_entity {
                            return None;
                        }
                        (last_timestamp + throttle_timeout).checked_duration_since(now)
                    })
                })
                .ok()
                .flatten();

            if let Some(timeout) = throttle_wait {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request. Also skip if another task already
            // proceeded since we were enqueued (duplicate).
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                let was_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);
                if was_cancelled {
                    return;
                }

                // Another request has been already sent since this was enqueued
                if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
                    return;
                }

                // Record this refresh in the throttle slot before sending the request.
                let new_refresh = (throttle_entity, cx.background_executor().now());
                *select_throttle(project_state, request_trigger) = Some(new_refresh);
                is_cancelled = false;
            })
            .ok();
            if is_cancelled {
                return None;
            }

            // Errors from the refresh are logged and treated as "no prediction".
            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                // The request may have been cancelled while it was in flight.
                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                                shown_with: None,
                                e2e_latency: prediction_result.e2e_latency,
                            };

                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                // Either the new prediction replaces the current one
                                // (which is then rejected as `Replaced`), or the new
                                // one is rejected as `CurrentPreferred`.
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                        cx,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                        new_prediction.prediction.model_version,
                                        Some(new_prediction.e2e_latency),
                                        cx,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            // The request itself produced a rejection; report it.
                            this.reject_prediction(
                                prediction_result.id,
                                reject_reason,
                                false,
                                None,
                                Some(prediction_result.e2e_latency),
                                cx,
                            );
                            None
                        }
                    }
                } else {
                    None
                };

                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Take the pending list out of the project state so it can be
                // edited while `project_state` is used to cancel the older entries.
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Track the new task. When the pending list is already at the provider's
        // limit, the most recently queued entry is replaced and cancelled.
        if project_state.pending_predictions.len() < max_pending_predictions {
            project_state
                .pending_predictions
                .push(PendingPrediction {
                    id: pending_prediction_id,
                    task,
                    drop_on_cancel,
                })
                .unwrap();
        } else {
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state
                .pending_predictions
                .push(PendingPrediction {
                    id: pending_prediction_id,
                    task,
                    drop_on_cancel,
                })
                .unwrap();
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
2249
2250    pub fn request_prediction(
2251        &mut self,
2252        project: &Entity<Project>,
2253        active_buffer: &Entity<Buffer>,
2254        position: language::Anchor,
2255        trigger: PredictEditsRequestTrigger,
2256        cx: &mut Context<Self>,
2257    ) -> Task<Result<Option<EditPredictionResult>>> {
2258        self.request_prediction_internal(
2259            project.clone(),
2260            active_buffer.clone(),
2261            position,
2262            trigger,
2263            cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2264            cx,
2265        )
2266    }
2267
2268    fn request_prediction_internal(
2269        &mut self,
2270        project: Entity<Project>,
2271        active_buffer: Entity<Buffer>,
2272        position: language::Anchor,
2273        trigger: PredictEditsRequestTrigger,
2274        allow_jump: bool,
2275        cx: &mut Context<Self>,
2276    ) -> Task<Result<Option<EditPredictionResult>>> {
2277        self.get_or_init_project(&project, cx);
2278        let project_state = self.projects.get(&project.entity_id()).unwrap();
2279        let stored_events = project_state.events(cx);
2280        let has_events = !stored_events.is_empty();
2281        let events: Vec<Arc<zeta_prompt::Event>> =
2282            stored_events.iter().map(|e| e.event.clone()).collect();
2283        let debug_tx = project_state.debug_tx.clone();
2284
2285        let snapshot = active_buffer.read(cx).snapshot();
2286        let cursor_point = position.to_point(&snapshot);
2287        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2288        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2289        let diagnostic_search_range =
2290            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2291
2292        let related_files = self.context_for_project(&project, cx);
2293
2294        let is_open_source = snapshot
2295            .file()
2296            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2297            && events.iter().all(|event| event.in_open_source_repo())
2298            && related_files.iter().all(|file| file.in_open_source_repo);
2299
2300        let can_collect_data = !cfg!(test)
2301            && is_open_source
2302            && self.is_data_collection_enabled(cx)
2303            && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2304
2305        let inputs = EditPredictionModelInput {
2306            project: project.clone(),
2307            buffer: active_buffer,
2308            snapshot,
2309            position,
2310            events,
2311            related_files,
2312            trigger,
2313            diagnostic_search_range: diagnostic_search_range,
2314            debug_tx,
2315            can_collect_data,
2316            is_open_source,
2317        };
2318
2319        let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2320
2321        let task = match self.edit_prediction_model {
2322            EditPredictionModel::Zeta => {
2323                zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2324            }
2325            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2326            EditPredictionModel::Mercury => {
2327                self.mercury
2328                    .request_prediction(inputs, self.credentials_provider.clone(), cx)
2329            }
2330        };
2331
2332        cx.spawn(async move |this, cx| {
2333            let prediction = task.await?;
2334
2335            // Only fall back to diagnostics-based prediction if we got a
2336            // the model had nothing to suggest for the buffer
2337            if prediction.is_none()
2338                && allow_jump
2339                && has_events
2340                && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2341            {
2342                this.update(cx, |this, cx| {
2343                    this.refresh_prediction_from_diagnostics(
2344                        project,
2345                        DiagnosticSearchScope::Local,
2346                        cx,
2347                    );
2348                })?;
2349                return anyhow::Ok(None);
2350            }
2351
2352            Ok(prediction)
2353        })
2354    }
2355
2356    pub(crate) async fn next_diagnostic_location(
2357        active_buffer: Entity<Buffer>,
2358        active_buffer_snapshot: &BufferSnapshot,
2359        active_buffer_diagnostic_search_range: Range<Point>,
2360        active_buffer_cursor_point: Point,
2361        project: &Entity<Project>,
2362        cx: &mut AsyncApp,
2363    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2364        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2365            .selections_in_range(
2366                Anchor::min_max_range_for_buffer(active_buffer_snapshot.remote_id()),
2367                false,
2368            )
2369            .flat_map(|(_, _, _, selections)| {
2370                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2371            })
2372            .collect();
2373
2374        let mut jump_location = active_buffer_snapshot
2375            .diagnostic_groups(None)
2376            .into_iter()
2377            .filter_map(|(_, group)| {
2378                let range = &group.entries[group.primary_ix]
2379                    .range
2380                    .to_point(&active_buffer_snapshot);
2381                if range.overlaps(&active_buffer_diagnostic_search_range) {
2382                    return None;
2383                }
2384                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2385                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2386                });
2387                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2388                    <= DIAGNOSTIC_LINES_RANGE;
2389                if near_collaborator && !near_local {
2390                    return None;
2391                }
2392                Some(range.start)
2393            })
2394            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2395            .map(|position| {
2396                (
2397                    active_buffer.clone(),
2398                    active_buffer_snapshot.anchor_before(position),
2399                )
2400            });
2401
2402        if jump_location.is_none() {
2403            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2404                let file = buffer.file()?;
2405
2406                Some(ProjectPath {
2407                    worktree_id: file.worktree_id(cx),
2408                    path: file.path().clone(),
2409                })
2410            });
2411
2412            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2413                project
2414                    .diagnostic_summaries(false, cx)
2415                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2416                    .map(|(path, _, _)| {
2417                        let shared_prefix = path
2418                            .path
2419                            .components()
2420                            .zip(
2421                                active_buffer_path
2422                                    .as_ref()
2423                                    .map(|p| p.path.components())
2424                                    .unwrap_or_default(),
2425                            )
2426                            .take_while(|(a, b)| a == b)
2427                            .count();
2428                        (path, shared_prefix)
2429                    })
2430                    .collect()
2431            });
2432
2433            candidates.sort_by(|a, b| b.1.cmp(&a.1));
2434
2435            for (path, _) in candidates {
2436                let candidate_buffer = project
2437                    .update(cx, |project, cx| project.open_buffer(path, cx))
2438                    .await?;
2439
2440                let (has_collaborators, diagnostic_position) =
2441                    candidate_buffer.read_with(cx, |buffer, _cx| {
2442                        let snapshot = buffer.snapshot();
2443                        let has_collaborators = snapshot
2444                            .selections_in_range(
2445                                Anchor::min_max_range_for_buffer(snapshot.remote_id()),
2446                                false,
2447                            )
2448                            .next()
2449                            .is_some();
2450                        let position = buffer
2451                            .buffer_diagnostics(None)
2452                            .into_iter()
2453                            .min_by_key(|entry| entry.diagnostic.severity)
2454                            .map(|entry| entry.range.start);
2455                        (has_collaborators, position)
2456                    });
2457
2458                if has_collaborators {
2459                    continue;
2460                }
2461
2462                if let Some(position) = diagnostic_position {
2463                    jump_location = Some((candidate_buffer, position));
2464                    break;
2465                }
2466            }
2467        }
2468
2469        anyhow::Ok(jump_location)
2470    }
2471
2472    async fn send_raw_llm_request(
2473        request: RawCompletionRequest,
2474        client: Arc<Client>,
2475        custom_url: Option<Arc<Url>>,
2476        llm_token: LlmApiToken,
2477        organization_id: Option<OrganizationId>,
2478        app_version: Version,
2479    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2480        let url = if let Some(custom_url) = custom_url {
2481            custom_url.as_ref().clone()
2482        } else {
2483            client
2484                .http_client()
2485                .build_zed_llm_url("/predict_edits/raw", &[])?
2486        };
2487
2488        Self::send_api_request(
2489            |builder| {
2490                let req = builder
2491                    .uri(url.as_ref())
2492                    .body(serde_json::to_string(&request)?.into());
2493                Ok(req?)
2494            },
2495            client,
2496            llm_token,
2497            organization_id,
2498            app_version,
2499            true,
2500        )
2501        .await
2502    }
2503
2504    pub(crate) async fn send_v3_request(
2505        input: ZetaPromptInput,
2506        client: Arc<Client>,
2507        llm_token: LlmApiToken,
2508        organization_id: Option<OrganizationId>,
2509        app_version: Version,
2510        trigger: PredictEditsRequestTrigger,
2511    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2512        let url = client
2513            .http_client()
2514            .build_zed_llm_url("/predict_edits/v3", &[])?;
2515
2516        let request = PredictEditsV3Request { input, trigger };
2517
2518        let json_bytes = serde_json::to_vec(&request)?;
2519        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2520
2521        Self::send_api_request(
2522            |builder| {
2523                let req = builder
2524                    .uri(url.as_ref())
2525                    .header("Content-Encoding", "zstd")
2526                    .body(compressed.clone().into());
2527                Ok(req?)
2528            },
2529            client,
2530            llm_token,
2531            organization_id,
2532            app_version,
2533            true,
2534        )
2535        .await
2536    }
2537
2538    async fn send_api_request<Res>(
2539        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2540        client: Arc<Client>,
2541        llm_token: LlmApiToken,
2542        organization_id: Option<OrganizationId>,
2543        app_version: Version,
2544        require_auth: bool,
2545    ) -> Result<(Res, Option<EditPredictionUsage>)>
2546    where
2547        Res: DeserializeOwned,
2548    {
2549        let http_client = client.http_client();
2550        let mut token = if require_auth {
2551            Some(
2552                client
2553                    .acquire_llm_token(&llm_token, organization_id.clone())
2554                    .await?,
2555            )
2556        } else {
2557            client
2558                .acquire_llm_token(&llm_token, organization_id.clone())
2559                .await
2560                .ok()
2561        };
2562        let mut did_retry = false;
2563
2564        loop {
2565            let request_builder = http_client::Request::builder().method(Method::POST);
2566
2567            let mut request_builder = request_builder
2568                .header("Content-Type", "application/json")
2569                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2570
2571            // Only add Authorization header if we have a token
2572            if let Some(ref token_value) = token {
2573                request_builder =
2574                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2575            }
2576
2577            let request = build(request_builder)?;
2578
2579            let mut response = http_client.send(request).await?;
2580
2581            if let Some(minimum_required_version) = response
2582                .headers()
2583                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2584                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2585            {
2586                anyhow::ensure!(
2587                    app_version >= minimum_required_version,
2588                    ZedUpdateRequiredError {
2589                        minimum_version: minimum_required_version
2590                    }
2591                );
2592            }
2593
2594            if response.status().is_success() {
2595                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2596
2597                let mut body = Vec::new();
2598                response.body_mut().read_to_end(&mut body).await?;
2599                return Ok((serde_json::from_slice(&body)?, usage));
2600            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2601                did_retry = true;
2602                token = Some(
2603                    client
2604                        .refresh_llm_token(&llm_token, organization_id.clone())
2605                        .await?,
2606                );
2607            } else {
2608                let mut body = String::new();
2609                response.body_mut().read_to_string(&mut body).await?;
2610                anyhow::bail!(
2611                    "Request failed with status: {:?}\nBody: {}",
2612                    response.status(),
2613                    body
2614                );
2615            }
2616        }
2617    }
2618
2619    pub fn refresh_context(
2620        &mut self,
2621        project: &Entity<Project>,
2622        buffer: &Entity<language::Buffer>,
2623        cursor_position: language::Anchor,
2624        cx: &mut Context<Self>,
2625    ) {
2626        self.get_or_init_project(project, cx)
2627            .context
2628            .update(cx, |store, cx| {
2629                store.refresh(buffer.clone(), cursor_position, cx);
2630            });
2631    }
2632
2633    #[cfg(feature = "cli-support")]
2634    pub fn set_context_for_buffer(
2635        &mut self,
2636        project: &Entity<Project>,
2637        related_files: Vec<RelatedFile>,
2638        cx: &mut Context<Self>,
2639    ) {
2640        self.get_or_init_project(project, cx)
2641            .context
2642            .update(cx, |store, cx| {
2643                store.set_related_files(related_files, cx);
2644            });
2645    }
2646
2647    #[cfg(feature = "cli-support")]
2648    pub fn set_recent_paths_for_project(
2649        &mut self,
2650        project: &Entity<Project>,
2651        paths: impl IntoIterator<Item = project::ProjectPath>,
2652        cx: &mut Context<Self>,
2653    ) {
2654        let project_state = self.get_or_init_project(project, cx);
2655        project_state.recent_paths = paths.into_iter().collect();
2656    }
2657
2658    fn is_file_open_source(
2659        &self,
2660        project: &Entity<Project>,
2661        file: &Arc<dyn File>,
2662        cx: &App,
2663    ) -> bool {
2664        if !file.is_local() || file.is_private() {
2665            return false;
2666        }
2667        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2668            return false;
2669        };
2670        project_state
2671            .license_detection_watchers
2672            .get(&file.worktree_id(cx))
2673            .as_ref()
2674            .is_some_and(|watcher| watcher.is_project_open_source())
2675    }
2676
2677    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2678        self.data_collection_choice.is_enabled(cx)
2679    }
2680
2681    fn load_data_collection_choice(cx: &App) -> DataCollectionChoice {
2682        let choice = KeyValueStore::global(cx)
2683            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2684            .log_err()
2685            .flatten();
2686
2687        match choice.as_deref() {
2688            Some("true") => DataCollectionChoice::Enabled,
2689            Some("false") => DataCollectionChoice::Disabled,
2690            Some(_) => {
2691                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2692                DataCollectionChoice::NotAnswered
2693            }
2694            None => DataCollectionChoice::NotAnswered,
2695        }
2696    }
2697
2698    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2699        self.data_collection_choice = self.data_collection_choice.toggle();
2700        let new_choice = self.data_collection_choice;
2701        let is_enabled = new_choice.is_enabled(cx);
2702        let kvp = KeyValueStore::global(cx);
2703        db::write_and_log(cx, move || async move {
2704            kvp.write_kvp(
2705                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2706                is_enabled.to_string(),
2707            )
2708            .await
2709        });
2710    }
2711
2712    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2713        self.shown_predictions.iter()
2714    }
2715
2716    pub fn shown_completions_len(&self) -> usize {
2717        self.shown_predictions.len()
2718    }
2719
2720    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2721        self.rated_predictions.contains(id)
2722    }
2723
2724    pub fn rate_prediction(
2725        &mut self,
2726        prediction: &EditPrediction,
2727        rating: EditPredictionRating,
2728        feedback: String,
2729        cx: &mut Context<Self>,
2730    ) {
2731        let organization = self.user_store.read(cx).current_organization();
2732
2733        self.rated_predictions.insert(prediction.id.clone());
2734
2735        cx.background_spawn({
2736            let client = self.client.clone();
2737            let prediction_id = prediction.id.to_string();
2738            let inputs = serde_json::to_value(&prediction.inputs);
2739            let output = prediction
2740                .edit_preview
2741                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2742            async move {
2743                client
2744                    .cloud_client()
2745                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2746                        organization_id: organization.map(|organization| organization.id.clone()),
2747                        request_id: prediction_id,
2748                        rating: match rating {
2749                            EditPredictionRating::Positive => "positive".to_string(),
2750                            EditPredictionRating::Negative => "negative".to_string(),
2751                        },
2752                        inputs: inputs?,
2753                        output,
2754                        feedback,
2755                    })
2756                    .await?;
2757
2758                anyhow::Ok(())
2759            }
2760        })
2761        .detach_and_log_err(cx);
2762
2763        cx.notify();
2764    }
2765}
2766
2767fn collaborator_edit_overlaps_locality_region(
2768    project_state: &ProjectState,
2769    project: &Entity<Project>,
2770    buffer: &Entity<Buffer>,
2771    snapshot: &BufferSnapshot,
2772    edit_range: &Range<Anchor>,
2773    cx: &App,
2774) -> bool {
2775    let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2776        return false;
2777    };
2778
2779    if active_buffer.entity_id() != buffer.entity_id() {
2780        return false;
2781    }
2782
2783    let locality_point_range = expand_context_syntactically_then_linewise(
2784        snapshot,
2785        (position..position).to_point(snapshot),
2786        COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2787    );
2788    let locality_anchor_range = snapshot.anchor_range_inside(locality_point_range);
2789
2790    edit_range.overlaps(&locality_anchor_range, snapshot)
2791}
2792
2793fn merge_trailing_events_if_needed(
2794    events: &mut VecDeque<StoredEvent>,
2795    end_snapshot: &TextBufferSnapshot,
2796    latest_snapshot: &TextBufferSnapshot,
2797    latest_edit_range: &Range<Anchor>,
2798) {
2799    if let Some(last_event) = events.back() {
2800        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
2801            return;
2802        }
2803        if !latest_snapshot
2804            .version
2805            .observed_all(&last_event.new_snapshot_version)
2806        {
2807            return;
2808        }
2809    }
2810
2811    let mut next_old_event = None;
2812    let mut mergeable_count = 0;
2813    for old_event in events.iter().rev() {
2814        if let Some(next_old_event) = next_old_event
2815            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
2816        {
2817            break;
2818        }
2819        mergeable_count += 1;
2820        next_old_event = Some(old_event);
2821    }
2822
2823    if mergeable_count <= 1 {
2824        return;
2825    }
2826
2827    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
2828    let oldest_event = events_to_merge.peek().unwrap();
2829    let oldest_snapshot = oldest_event.old_snapshot.clone();
2830    let newest_snapshot = end_snapshot;
2831    let mut merged_edit_range = oldest_event.total_edit_range.clone();
2832
2833    for event in events.range(events.len() - mergeable_count + 1..) {
2834        merged_edit_range =
2835            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
2836    }
2837
2838    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
2839        &oldest_snapshot,
2840        newest_snapshot,
2841        &merged_edit_range,
2842    ) {
2843        let merged_event = match oldest_event.event.as_ref() {
2844            zeta_prompt::Event::BufferChange {
2845                old_path,
2846                path,
2847                in_open_source_repo,
2848                ..
2849            } => StoredEvent {
2850                event: Arc::new(zeta_prompt::Event::BufferChange {
2851                    old_path: old_path.clone(),
2852                    path: path.clone(),
2853                    diff,
2854                    in_open_source_repo: *in_open_source_repo,
2855                    predicted: events_to_merge.all(|e| {
2856                        matches!(
2857                            e.event.as_ref(),
2858                            zeta_prompt::Event::BufferChange {
2859                                predicted: true,
2860                                ..
2861                            }
2862                        )
2863                    }),
2864                }),
2865                old_snapshot: oldest_snapshot.clone(),
2866                new_snapshot_version: newest_snapshot.version.clone(),
2867                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
2868                    ..newest_snapshot.anchor_before(edit_range.end),
2869            },
2870        };
2871        events.truncate(events.len() - mergeable_count);
2872        events.push_back(merged_event);
2873    }
2874}
2875
2876fn merge_anchor_ranges(
2877    left: &Range<Anchor>,
2878    right: &Range<Anchor>,
2879    snapshot: &TextBufferSnapshot,
2880) -> Range<Anchor> {
2881    let start = if left.start.cmp(&right.start, snapshot).is_le() {
2882        left.start
2883    } else {
2884        right.start
2885    };
2886    let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2887        left.end
2888    } else {
2889        right.end
2890    };
2891    start..end
2892}
2893
/// Error produced by `send_api_request` when a response carries a minimum
/// required version header naming a version newer than the running app.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum version parsed from the response header.
    minimum_version: Version,
}
2901
2902#[derive(Debug, Clone, Copy)]
2903pub enum DataCollectionChoice {
2904    NotAnswered,
2905    Enabled,
2906    Disabled,
2907}
2908
2909impl DataCollectionChoice {
2910    pub fn is_enabled(self, cx: &App) -> bool {
2911        if cx.is_staff() {
2912            return true;
2913        }
2914        match self {
2915            Self::Enabled => true,
2916            Self::NotAnswered | Self::Disabled => false,
2917        }
2918    }
2919
2920    #[must_use]
2921    pub fn toggle(&self) -> DataCollectionChoice {
2922        match self {
2923            Self::Enabled => Self::Disabled,
2924            Self::Disabled => Self::Enabled,
2925            Self::NotAnswered => Self::Enabled,
2926        }
2927    }
2928}
2929
2930impl From<bool> for DataCollectionChoice {
2931    fn from(value: bool) -> Self {
2932        match value {
2933            true => DataCollectionChoice::Enabled,
2934            false => DataCollectionChoice::Disabled,
2935        }
2936    }
2937}
2938
2939struct ZedPredictUpsell;
2940
2941impl Dismissable for ZedPredictUpsell {
2942    const KEY: &'static str = "dismissed-edit-predict-upsell";
2943
2944    fn dismissed(cx: &App) -> bool {
2945        // To make this backwards compatible with older versions of Zed, we
2946        // check if the user has seen the previous Edit Prediction Onboarding
2947        // before, by checking the data collection choice which was written to
2948        // the database once the user clicked on "Accept and Enable"
2949        let kvp = KeyValueStore::global(cx);
2950        if kvp
2951            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2952            .log_err()
2953            .is_some_and(|s| s.is_some())
2954        {
2955            return true;
2956        }
2957
2958        kvp.read_kvp(Self::KEY)
2959            .log_err()
2960            .is_some_and(|s| s.is_some())
2961    }
2962}
2963
2964pub fn should_show_upsell_modal(cx: &App) -> bool {
2965    !ZedPredictUpsell::dismissed(cx)
2966}
2967
2968pub fn init(cx: &mut App) {
2969    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2970        workspace.register_action(
2971            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2972                ZedPredictModal::toggle(
2973                    workspace,
2974                    workspace.user_store().clone(),
2975                    workspace.client().clone(),
2976                    window,
2977                    cx,
2978                )
2979            },
2980        );
2981
2982        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2983            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2984                settings
2985                    .project
2986                    .all_languages
2987                    .edit_predictions
2988                    .get_or_insert_default()
2989                    .provider = Some(EditPredictionProvider::None)
2990            });
2991        });
2992        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2993            EditPredictionStore::try_global(cx).and_then(|store| {
2994                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2995            })
2996        }
2997
2998        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2999            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3000                copilot_ui::initiate_sign_in(copilot, window, cx);
3001            }
3002        });
3003        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3004            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3005                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3006            }
3007        });
3008        workspace.register_action(|workspace, _: &SignOut, window, cx| {
3009            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3010                copilot_ui::initiate_sign_out(copilot, window, cx);
3011            }
3012        });
3013    })
3014    .detach();
3015}