edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
   5use cloud_llm_client::predict_edits_v3::{
   6    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   7};
   8use cloud_llm_client::{
   9    EditPredictionRejectReason, EditPredictionRejection,
  10    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  11    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  12};
  13use collections::{HashMap, HashSet};
  14use copilot::{Copilot, Reinstall, SignIn, SignOut};
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::{
  19    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  20    channel::mpsc::{self, UnboundedReceiver},
  21    select_biased,
  22};
  23use gpui::BackgroundExecutor;
  24use gpui::http_client::Url;
  25use gpui::{
  26    App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
  27    http_client::{self, AsyncBody, Method},
  28    prelude::*,
  29};
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{
  39    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  40};
  41use std::collections::{VecDeque, hash_map};
  42use std::env;
  43use text::{AnchorRangeExt, Edit};
  44use workspace::{AppState, Workspace};
  45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  46
  47use std::mem;
  48use std::ops::Range;
  49use std::path::Path;
  50use std::rc::Rc;
  51use std::str::FromStr as _;
  52use std::sync::Arc;
  53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  54use thiserror::Error;
  55use util::{RangeExt as _, ResultExt as _};
  56
  57pub mod cursor_excerpt;
  58pub mod example_spec;
  59pub mod fim;
  60mod license_detection;
  61pub mod mercury;
  62pub mod ollama;
  63mod onboarding_modal;
  64pub mod open_ai_response;
  65mod prediction;
  66pub mod sweep_ai;
  67
  68pub mod udiff;
  69
  70mod capture_example;
  71pub mod open_ai_compatible;
  72mod zed_edit_prediction_delegate;
  73pub mod zeta;
  74
  75#[cfg(test)]
  76mod edit_prediction_tests;
  77
  78use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
  79use crate::example_spec::ExampleSpec;
  80use crate::license_detection::LicenseDetectionWatcher;
  81use crate::mercury::Mercury;
  82use crate::onboarding_modal::ZedPredictModal;
  83pub use crate::prediction::EditPrediction;
  84pub use crate::prediction::EditPredictionId;
  85use crate::prediction::EditPredictionResult;
  86pub use crate::sweep_ai::SweepAi;
  87pub use capture_example::capture_example;
  88pub use language_model::ApiKeyState;
  89pub use telemetry_events::EditPredictionRating;
  90pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  91
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
 101
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Row distance used by `StoredEvent::can_merge`: events within this many rows of the
/// latest edit are kept separate, and events farther apart than this are not merged.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Token budget for context around collaborator edits (use site not visible here — TODO confirm).
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
/// Time window for grouping consecutive changes (presumably into a single event — confirm at use site).
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Persistence key for the data collection choice (presumably in the key-value store — confirm).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval for reporting rejected predictions (presumably batched — confirm in the reporter).
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
/// Name of the event reported when a prediction settles (presumably a telemetry event).
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
/// Five-minute lifetime for pending settled predictions (presumably discarded afterwards — confirm).
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
/// Quiet period after which a prediction counts as settled (presumably measured since the last edit — confirm).
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
 112
/// Feature flag named `edit_prediction_jumps` (presumably gating jump-style predictions — confirm).
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
 118
/// App-global slot holding the shared [`EditPredictionStore`] entity.
///
/// `EditPredictionStore::global` installs it and `EditPredictionStore::try_global` reads it.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 123
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// `EditPredictionStore::new` initializes it from the `ZED_ZETA_FORMAT`,
/// `ZED_ZETA_MODEL` and `ZED_ZETA_ENVIRONMENT` environment variables. It can also be
/// installed with `EditPredictionStore::set_zeta2_raw_config`.
///
/// NOTE(review): the fallback used when `environment` is `None` is not visible here.
/// It may be derived from `format`; confirm at the request site.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Model identifier (`ZED_ZETA_MODEL` when loaded from the environment).
    pub model_id: Option<String>,
    /// Deployment environment name (`ZED_ZETA_ENVIRONMENT` when loaded from the environment).
    pub environment: Option<String>,
    /// Prompt format (parsed from `ZED_ZETA_FORMAT` when loaded from the environment).
    pub format: ZetaFormat,
}
 133
/// Central store for edit predictions.
///
/// Holds per-project state, the selected prediction backend, experiment selection,
/// and channels to the background workers started in `EditPredictionStore::new`.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    /// Waits for a current user, then refreshes experiments for staff users.
    _fetch_experiments_task: Task<()>,
    /// Per-project state, keyed by entity id (presumably the project entity's id — confirm).
    projects: HashMap<EntityId, ProjectState>,
    /// NOTE(review): not used in this chunk. Presumably set when the server demands a
    /// newer client; confirm.
    update_required: bool,
    /// Backend used for predictions; `new` initializes it to `EditPredictionModel::Zeta`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration; `new` reads it from the environment.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    /// Explicitly chosen experiment; takes precedence in `active_experiment`.
    preferred_experiment: Option<String>,
    /// Experiments available for selection.
    available_experiments: Vec<String>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Feeds the background `handle_rejected_predictions` task started in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Feeds the `run_settled_predictions_worker` task started in `new`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    /// Predictions that have been shown (per the name); `active_experiment` searches them.
    shown_predictions: VecDeque<EditPrediction>,
    /// Ids of predictions that have been rated (per the name).
    rated_predictions: HashSet<EditPredictionId>,
    /// Test-only hook receiving a prediction id and a string (presumably the settled text — confirm).
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
 155
/// A rejection sent over `EditPredictionStore::reject_predictions_tx`.
/// It carries the organization the rejection is attributed to, if any.
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}
 160
/// Backend used to produce edit predictions. `EditPredictionStore::new` selects `Zeta`.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta,
    /// FIM-style prompting using the given prompt format.
    Fim { format: EditPredictionPromptFormat },
    Sweep,
    Mercury,
}
 168
/// Inputs handed to a prediction backend for a single request.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    /// Buffer the prediction is requested for.
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` used to build the request.
    snapshot: BufferSnapshot,
    /// Requested position in `buffer` (presumably the cursor — confirm).
    position: Anchor,
    /// Recent edit events supplied as history.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    /// Rows searched for diagnostics.
    diagnostic_search_range: Range<Point>,
    /// Optional channel for [`DebugEvent`]s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    pub user_actions: Vec<UserActionRecord>,
}
 185
/// Debug notifications sent over the optional `debug_tx` channels held by
/// `ProjectState` and `EditPredictionModelInput`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Context retrieval began for a project.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Context retrieval completed for a project, with label/value metadata pairs.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// A prediction request started; `prompt` holds the prompt text when available.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// A prediction request finished; `model_output` holds the model's output when available.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
 221
/// Maximum number of user actions kept per project by `ProjectState::record_user_action`.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One recorded user action within a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    pub buffer_id: EntityId,
    /// Row of the action (presumably 0-based — TODO confirm at the recording site).
    pub line_number: u32,
    /// Offset of the action in the buffer (presumably bytes — confirm).
    pub offset: usize,
    /// Wall-clock time of the action, in milliseconds since the Unix epoch (per the name).
    pub timestamp_epoch_ms: u64,
}

/// Kind of action captured in a [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 241
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The recorded change.
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot on the "before" side of the event's diff.
    pub old_snapshot: TextBufferSnapshot,
    /// Buffer version on the "after" side of the event's diff.
    pub new_snapshot_version: clock::Global,
    /// Range covering the event's edits, anchored in the post-edit buffer.
    pub total_edit_range: Range<Anchor>,
}
 250
 251impl StoredEvent {
 252    fn can_merge(
 253        &self,
 254        next_old_event: &StoredEvent,
 255        latest_snapshot: &TextBufferSnapshot,
 256        latest_edit_range: &Range<Anchor>,
 257    ) -> bool {
 258        // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
 259        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 260            return false;
 261        }
 262        if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
 263            return false;
 264        }
 265        if self.new_snapshot_version != next_old_event.old_snapshot.version {
 266            return false;
 267        }
 268        if !latest_snapshot
 269            .version
 270            .observed_all(&next_old_event.new_snapshot_version)
 271        {
 272            return false;
 273        }
 274
 275        let a_is_predicted = matches!(
 276            self.event.as_ref(),
 277            zeta_prompt::Event::BufferChange {
 278                predicted: true,
 279                ..
 280            }
 281        );
 282        let b_is_predicted = matches!(
 283            next_old_event.event.as_ref(),
 284            zeta_prompt::Event::BufferChange {
 285                predicted: true,
 286                ..
 287            }
 288        );
 289
 290        // If events come from the same source (both predicted or both manual) then
 291        // we would have coalesced them already.
 292        if a_is_predicted == b_is_predicted {
 293            return false;
 294        }
 295
 296        let left_range = self.total_edit_range.to_point(latest_snapshot);
 297        let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
 298        let latest_range = latest_edit_range.to_point(latest_snapshot);
 299
 300        // Events near to the latest edit are not merged if their sources differ.
 301        if lines_between_ranges(&left_range, &latest_range)
 302            .min(lines_between_ranges(&right_range, &latest_range))
 303            <= CHANGE_GROUPING_LINE_SPAN
 304        {
 305            return false;
 306        }
 307
 308        // Events that are distant from each other are not merged.
 309        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 310            return false;
 311        }
 312
 313        true
 314    }
 315}
 316
 317fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 318    if left.start > right.end {
 319        return left.start.row - right.end.row;
 320    }
 321    if right.start > left.end {
 322        return right.start.row - left.end.row;
 323    }
 324    0
 325}
 326
/// Per-project prediction state owned by [`EditPredictionStore`].
struct ProjectState {
    /// Finalized edit events (presumably capped by `EVENT_COUNT_MAX` — confirm).
    events: VecDeque<StoredEvent>,
    /// In-progress event not yet in `events`; `events()` finalizes it on demand.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// Id for the next pending prediction (presumably incremented per request — confirm).
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Optional channel for [`DebugEvent`]s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// Per-worktree license watchers; used to flag events as open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history of user actions (see `record_user_action`).
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
 345
 346impl ProjectState {
 347    fn record_user_action(&mut self, action: UserActionRecord) {
 348        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 349            self.user_actions.pop_front();
 350        }
 351        self.user_actions.push_back(action);
 352    }
 353
 354    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 355        self.events
 356            .iter()
 357            .cloned()
 358            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 359                let (one, two) = event.split_by_pause();
 360                let one = one.finalize(&self.license_detection_watchers, cx);
 361                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 362                one.into_iter().chain(two)
 363            }))
 364            .collect()
 365    }
 366
 367    fn cancel_pending_prediction(
 368        &mut self,
 369        pending_prediction: PendingPrediction,
 370        cx: &mut Context<EditPredictionStore>,
 371    ) {
 372        self.cancelled_predictions.insert(pending_prediction.id);
 373
 374        if pending_prediction.drop_on_cancel {
 375            drop(pending_prediction.task);
 376        } else {
 377            cx.spawn(async move |this, cx| {
 378                let Some(prediction_id) = pending_prediction.task.await else {
 379                    return;
 380                };
 381
 382                this.update(cx, |this, cx| {
 383                    this.reject_prediction(
 384                        prediction_id,
 385                        EditPredictionRejectReason::Canceled,
 386                        false,
 387                        None,
 388                        None,
 389                        cx,
 390                    );
 391                })
 392                .ok();
 393            })
 394            .detach()
 395        }
 396    }
 397
 398    fn active_buffer(
 399        &self,
 400        project: &Entity<Project>,
 401        cx: &App,
 402    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 403        let project = project.read(cx);
 404        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 405        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 406        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 407        Some((active_buffer, registered_buffer.last_position))
 408    }
 409}
 410
/// The prediction currently held for a project (`ProjectState::current_prediction`),
/// with display and timing metadata.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been shown (per the name).
    pub was_shown: bool,
    /// How the prediction was displayed, if it was.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
    /// End-to-end latency (presumably from request to result — confirm).
    pub e2e_latency: std::time::Duration,
}
 419
 420impl CurrentEditPrediction {
 421    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 422        let Some(new_edits) = self
 423            .prediction
 424            .interpolate(&self.prediction.buffer.read(cx))
 425        else {
 426            return false;
 427        };
 428
 429        if self.prediction.buffer != old_prediction.prediction.buffer {
 430            return true;
 431        }
 432
 433        let Some(old_edits) = old_prediction
 434            .prediction
 435            .interpolate(&old_prediction.prediction.buffer.read(cx))
 436        else {
 437            return true;
 438        };
 439
 440        let requested_by_buffer_id = self.requested_by.buffer_id();
 441
 442        // This reduces the occurrence of UI thrash from replacing edits
 443        //
 444        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 445        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 446            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 447            && old_edits.len() == 1
 448            && new_edits.len() == 1
 449        {
 450            let (old_range, old_text) = &old_edits[0];
 451            let (new_range, new_text) = &new_edits[0];
 452            new_range == old_range && new_text.starts_with(old_text.as_ref())
 453        } else {
 454            true
 455        }
 456    }
 457}
 458
 459#[derive(Debug, Clone)]
 460enum PredictionRequestedBy {
 461    DiagnosticsUpdate,
 462    Buffer(EntityId),
 463}
 464
 465impl PredictionRequestedBy {
 466    pub fn buffer_id(&self) -> Option<EntityId> {
 467        match self {
 468            PredictionRequestedBy::DiagnosticsUpdate => None,
 469            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 470        }
 471    }
 472}
 473
/// Row radius for diagnostic searches (use site not visible here — presumably rows around
/// the cursor; confirm).
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostic search (presumably nearby-only vs. project-wide — confirm at use sites).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
 481
/// An in-flight prediction request tracked in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier added to `cancelled_predictions` when this request is cancelled.
    id: usize,
    /// Resolves to the produced prediction's id, or `None` if none was produced.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 490
 491/// A prediction from the perspective of a buffer.
 492#[derive(Debug)]
 493enum BufferEditPrediction<'a> {
 494    Local { prediction: &'a EditPrediction },
 495    Jump { prediction: &'a EditPrediction },
 496}
 497
 498#[cfg(test)]
 499impl std::ops::Deref for BufferEditPrediction<'_> {
 500    type Target = EditPrediction;
 501
 502    fn deref(&self) -> &Self::Target {
 503        match self {
 504            BufferEditPrediction::Local { prediction } => prediction,
 505            BufferEditPrediction::Jump { prediction } => prediction,
 506        }
 507    }
 508}
 509
/// A prediction awaiting a "settled" report.
///
/// NOTE(review): presumably reported as `EDIT_PREDICTION_SETTLED_EVENT` once edits
/// quiesce or the TTL expires; confirm in the settled-predictions worker.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Editable region of the prediction, anchored in its buffer.
    editable_anchor_range: Range<Anchor>,
    example: Option<ExampleSpec>,
    /// When this entry was queued (per the name).
    enqueued_at: Instant,
    /// Time of the most recent edit seen for this entry (per the name — confirm where updated).
    last_edit_at: Instant,
    e2e_latency: std::time::Duration,
}
 520
/// Tracking state for a buffer in `ProjectState::registered_buffers`.
struct RegisteredBuffer {
    /// The buffer's file, if any (presumably the most recently observed — confirm).
    file: Option<Arc<dyn File>>,
    /// Last recorded snapshot of the buffer (presumably refreshed on edits — confirm).
    snapshot: TextBufferSnapshot,
    /// Predictions in this buffer awaiting a settled report.
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last known position (presumably the cursor); returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
 528
/// The in-progress, not yet finalized, edit event for a buffer.
///
/// `finalize` converts it into a [`StoredEvent`], and `split_by_pause` divides it at
/// the last editing pause.
#[derive(Clone)]
struct LastEvent {
    /// "Before" side of the event's diff.
    old_snapshot: TextBufferSnapshot,
    /// "After" side of the event's diff.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    /// Range of the most recent edit. `split_by_pause` sets it to each segment's total range.
    latest_edit_range: Range<Anchor>,
    /// Range covering all edits in this event, resolved against `new_snapshot` when finalized.
    total_edit_range: Range<Anchor>,
    /// `total_edit_range` as it stood at the last editing pause, if recorded.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    /// Whether these are predicted edits (as opposed to manual ones).
    predicted: bool,
    /// Buffer snapshot taken at the last editing pause, if any.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 542
 543impl LastEvent {
 544    pub fn finalize(
 545        &self,
 546        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 547        cx: &App,
 548    ) -> Option<StoredEvent> {
 549        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 550        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 551
 552        let in_open_source_repo =
 553            [self.new_file.as_ref(), self.old_file.as_ref()]
 554                .iter()
 555                .all(|file| {
 556                    file.is_some_and(|file| {
 557                        license_detection_watchers
 558                            .get(&file.worktree_id(cx))
 559                            .is_some_and(|watcher| watcher.is_project_open_source())
 560                    })
 561                });
 562
 563        let (diff, edit_range) = compute_diff_between_snapshots_in_range(
 564            &self.old_snapshot,
 565            &self.new_snapshot,
 566            &self.total_edit_range,
 567        )?;
 568
 569        if path == old_path && diff.is_empty() {
 570            None
 571        } else {
 572            Some(StoredEvent {
 573                event: Arc::new(zeta_prompt::Event::BufferChange {
 574                    old_path,
 575                    path,
 576                    diff,
 577                    in_open_source_repo,
 578                    predicted: self.predicted,
 579                }),
 580                old_snapshot: self.old_snapshot.clone(),
 581                new_snapshot_version: self.new_snapshot.version.clone(),
 582                total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
 583                    ..self.new_snapshot.anchor_before(edit_range.end),
 584            })
 585        }
 586    }
 587
 588    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 589        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 590            return (self.clone(), None);
 591        };
 592
 593        let total_edit_range_before_pause = self
 594            .total_edit_range_at_last_pause_boundary
 595            .clone()
 596            .unwrap_or_else(|| self.total_edit_range.clone());
 597
 598        let Some(total_edit_range_after_pause) =
 599            compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
 600        else {
 601            return (self.clone(), None);
 602        };
 603
 604        let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
 605        let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
 606
 607        let before = LastEvent {
 608            old_snapshot: self.old_snapshot.clone(),
 609            new_snapshot: boundary_snapshot.clone(),
 610            old_file: self.old_file.clone(),
 611            new_file: self.new_file.clone(),
 612            latest_edit_range: latest_edit_range_before_pause,
 613            total_edit_range: total_edit_range_before_pause,
 614            total_edit_range_at_last_pause_boundary: None,
 615            predicted: self.predicted,
 616            snapshot_after_last_editing_pause: None,
 617            last_edit_time: self.last_edit_time,
 618        };
 619
 620        let after = LastEvent {
 621            old_snapshot: boundary_snapshot.clone(),
 622            new_snapshot: self.new_snapshot.clone(),
 623            old_file: self.old_file.clone(),
 624            new_file: self.new_file.clone(),
 625            latest_edit_range: latest_edit_range_after_pause,
 626            total_edit_range: total_edit_range_after_pause,
 627            total_edit_range_at_last_pause_boundary: None,
 628            predicted: self.predicted,
 629            snapshot_after_last_editing_pause: None,
 630            last_edit_time: self.last_edit_time,
 631        };
 632
 633        (before, Some(after))
 634    }
 635}
 636
 637fn compute_total_edit_range_between_snapshots(
 638    old_snapshot: &TextBufferSnapshot,
 639    new_snapshot: &TextBufferSnapshot,
 640) -> Option<Range<Anchor>> {
 641    let edits: Vec<Edit<usize>> = new_snapshot
 642        .edits_since::<usize>(&old_snapshot.version)
 643        .collect();
 644
 645    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 646    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 647    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 648
 649    Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
 650}
 651
 652fn compute_old_range_for_new_range(
 653    old_snapshot: &TextBufferSnapshot,
 654    new_snapshot: &TextBufferSnapshot,
 655    total_edit_range: &Range<Anchor>,
 656) -> Option<Range<Point>> {
 657    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
 658    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);
 659
 660    let edits: Vec<Edit<usize>> = new_snapshot
 661        .edits_since::<usize>(&old_snapshot.version)
 662        .collect();
 663    let mut old_start_offset = None;
 664    let mut old_end_offset = None;
 665    let mut delta: isize = 0;
 666
 667    for edit in &edits {
 668        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
 669            old_start_offset = Some(if new_start_offset < edit.new.start {
 670                new_start_offset.checked_add_signed(-delta)?
 671            } else {
 672                edit.old.start
 673            });
 674        }
 675
 676        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
 677            old_end_offset = Some(if new_end_offset < edit.new.start {
 678                new_end_offset.checked_add_signed(-delta)?
 679            } else {
 680                edit.old.end
 681            });
 682        }
 683
 684        delta += edit.new.len() as isize - edit.old.len() as isize;
 685    }
 686
 687    let old_start_offset =
 688        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
 689    let old_end_offset =
 690        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));
 691
 692    Some(
 693        old_snapshot.offset_to_point(old_start_offset)
 694            ..old_snapshot.offset_to_point(old_end_offset),
 695    )
 696}
 697
 698fn compute_diff_between_snapshots_in_range(
 699    old_snapshot: &TextBufferSnapshot,
 700    new_snapshot: &TextBufferSnapshot,
 701    total_edit_range: &Range<Anchor>,
 702) -> Option<(String, Range<Point>)> {
 703    let new_start_point = total_edit_range.start.to_point(new_snapshot);
 704    let new_end_point = total_edit_range.end.to_point(new_snapshot);
 705    let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
 706    let old_start_point = old_range.start;
 707    let old_end_point = old_range.end;
 708
 709    const CONTEXT_LINES: u32 = 3;
 710
 711    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 712    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 713    let old_context_end_row =
 714        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 715    let new_context_end_row =
 716        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 717
 718    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 719    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 720    let old_end_line_offset = old_snapshot
 721        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 722    let new_end_line_offset = new_snapshot
 723        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 724    let old_edit_range = old_start_line_offset..old_end_line_offset;
 725    let new_edit_range = new_start_line_offset..new_end_line_offset;
 726
 727    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 728    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 729
 730    let diff = language::unified_diff_with_offsets(
 731        &old_region_text,
 732        &new_region_text,
 733        old_context_start_row,
 734        new_context_start_row,
 735    );
 736
 737    Some((diff, new_start_point..new_end_point))
 738}
 739
 740fn buffer_path_with_id_fallback(
 741    file: Option<&Arc<dyn File>>,
 742    snapshot: &TextBufferSnapshot,
 743    cx: &App,
 744) -> Arc<Path> {
 745    if let Some(file) = file {
 746        file.full_path(cx).into()
 747    } else {
 748        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 749    }
 750}
 751
 752impl EditPredictionStore {
 753    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 754        cx.try_global::<EditPredictionStoreGlobal>()
 755            .map(|global| global.0.clone())
 756    }
 757
 758    pub fn global(
 759        client: &Arc<Client>,
 760        user_store: &Entity<UserStore>,
 761        cx: &mut App,
 762    ) -> Entity<Self> {
 763        cx.try_global::<EditPredictionStoreGlobal>()
 764            .map(|global| global.0.clone())
 765            .unwrap_or_else(|| {
 766                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 767                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 768                ep_store
 769            })
 770    }
 771
 772    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 773        let data_collection_choice = Self::load_data_collection_choice();
 774
 775        let llm_token = LlmApiToken::global(cx);
 776
 777        let (reject_tx, reject_rx) = mpsc::unbounded();
 778        cx.background_spawn({
 779            let client = client.clone();
 780            let llm_token = llm_token.clone();
 781            let app_version = AppVersion::global(cx);
 782            let background_executor = cx.background_executor().clone();
 783            async move {
 784                Self::handle_rejected_predictions(
 785                    reject_rx,
 786                    client,
 787                    llm_token,
 788                    app_version,
 789                    background_executor,
 790                )
 791                .await
 792            }
 793        })
 794        .detach();
 795
 796        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
 797        cx.spawn(async move |this, cx| {
 798            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
 799        })
 800        .detach();
 801
 802        let mut current_user = user_store.read(cx).watch_current_user();
 803        let fetch_experiments_task = cx.spawn(async move |this, cx| {
 804            while current_user.borrow().is_none() {
 805                current_user.next().await;
 806            }
 807
 808            this.update(cx, |this, cx| {
 809                if cx.is_staff() {
 810                    this.refresh_available_experiments(cx);
 811                }
 812            })
 813            .log_err();
 814        });
 815
 816        let this = Self {
 817            projects: HashMap::default(),
 818            client,
 819            user_store,
 820            llm_token,
 821            _fetch_experiments_task: fetch_experiments_task,
 822            update_required: false,
 823            edit_prediction_model: EditPredictionModel::Zeta,
 824            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 825            preferred_experiment: None,
 826            available_experiments: Vec::new(),
 827            sweep_ai: SweepAi::new(cx),
 828            mercury: Mercury::new(cx),
 829
 830            data_collection_choice,
 831            reject_predictions_tx: reject_tx,
 832            settled_predictions_tx,
 833            rated_predictions: Default::default(),
 834            shown_predictions: Default::default(),
 835            #[cfg(test)]
 836            settled_event_callback: None,
 837        };
 838
 839        this
 840    }
 841
 842    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 843        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 844        let format = ZetaFormat::parse(&version_str).ok()?;
 845        let model_id = env::var("ZED_ZETA_MODEL").ok();
 846        let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
 847        Some(Zeta2RawConfig {
 848            model_id,
 849            environment,
 850            format,
 851        })
 852    }
 853
 854    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 855        self.edit_prediction_model = model;
 856    }
 857
 858    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
 859        self.zeta2_raw_config = Some(config);
 860    }
 861
 862    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
 863        self.zeta2_raw_config.as_ref()
 864    }
 865
 866    pub fn preferred_experiment(&self) -> Option<&str> {
 867        self.preferred_experiment.as_deref()
 868    }
 869
 870    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
 871        self.preferred_experiment = experiment;
 872    }
 873
 874    pub fn available_experiments(&self) -> &[String] {
 875        &self.available_experiments
 876    }
 877
 878    pub fn active_experiment(&self) -> Option<&str> {
 879        self.preferred_experiment.as_deref().or_else(|| {
 880            self.shown_predictions
 881                .iter()
 882                .find_map(|p| p.model_version.as_ref())
 883                .and_then(|model_version| model_version.strip_prefix("zeta2:"))
 884        })
 885    }
 886
 887    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
 888        let client = self.client.clone();
 889        let llm_token = self.llm_token.clone();
 890        let app_version = AppVersion::global(cx);
 891        let organization_id = self
 892            .user_store
 893            .read(cx)
 894            .current_organization()
 895            .map(|organization| organization.id.clone());
 896
 897        cx.spawn(async move |this, cx| {
 898            let experiments = cx
 899                .background_spawn(async move {
 900                    let http_client = client.http_client();
 901                    let token = llm_token.acquire(&client, organization_id).await?;
 902                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
 903                    let request = http_client::Request::builder()
 904                        .method(Method::GET)
 905                        .uri(url.as_ref())
 906                        .header("Authorization", format!("Bearer {}", token))
 907                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 908                        .body(Default::default())?;
 909                    let mut response = http_client.send(request).await?;
 910                    if response.status().is_success() {
 911                        let mut body = Vec::new();
 912                        response.body_mut().read_to_end(&mut body).await?;
 913                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
 914                        Ok(experiments)
 915                    } else {
 916                        let mut body = String::new();
 917                        response.body_mut().read_to_string(&mut body).await?;
 918                        anyhow::bail!(
 919                            "Failed to fetch experiments: {:?}\nBody: {}",
 920                            response.status(),
 921                            body
 922                        );
 923                    }
 924                })
 925                .await?;
 926            this.update(cx, |this, cx| {
 927                this.available_experiments = experiments;
 928                cx.notify();
 929            })?;
 930            anyhow::Ok(())
 931        })
 932        .detach_and_log_err(cx);
 933    }
 934
 935    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 936        use ui::IconName;
 937        match self.edit_prediction_model {
 938            EditPredictionModel::Sweep => {
 939                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 940                    .with_disabled(IconName::SweepAiDisabled)
 941                    .with_up(IconName::SweepAiUp)
 942                    .with_down(IconName::SweepAiDown)
 943                    .with_error(IconName::SweepAiError)
 944            }
 945            EditPredictionModel::Mercury => {
 946                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 947            }
 948            EditPredictionModel::Zeta => {
 949                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 950                    .with_disabled(IconName::ZedPredictDisabled)
 951                    .with_up(IconName::ZedPredictUp)
 952                    .with_down(IconName::ZedPredictDown)
 953                    .with_error(IconName::ZedPredictError)
 954            }
 955            EditPredictionModel::Fim { .. } => {
 956                let settings = &all_language_settings(None, cx).edit_predictions;
 957                match settings.provider {
 958                    EditPredictionProvider::Ollama => {
 959                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 960                    }
 961                    _ => {
 962                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 963                    }
 964                }
 965            }
 966        }
 967    }
 968
 969    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 970        self.sweep_ai.api_token.read(cx).has_key()
 971    }
 972
 973    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 974        self.mercury.api_token.read(cx).has_key()
 975    }
 976
 977    pub fn mercury_has_payment_required_error(&self) -> bool {
 978        self.mercury.has_payment_required_error()
 979    }
 980
 981    pub fn clear_history(&mut self) {
 982        for project_state in self.projects.values_mut() {
 983            project_state.events.clear();
 984            project_state.last_event.take();
 985        }
 986    }
 987
 988    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 989        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 990            project_state.events.clear();
 991            project_state.last_event.take();
 992        }
 993    }
 994
 995    pub fn edit_history_for_project(
 996        &self,
 997        project: &Entity<Project>,
 998        cx: &App,
 999    ) -> Vec<StoredEvent> {
1000        self.projects
1001            .get(&project.entity_id())
1002            .map(|project_state| project_state.events(cx))
1003            .unwrap_or_default()
1004    }
1005
1006    pub fn context_for_project<'a>(
1007        &'a self,
1008        project: &Entity<Project>,
1009        cx: &'a mut App,
1010    ) -> Vec<RelatedFile> {
1011        self.projects
1012            .get(&project.entity_id())
1013            .map(|project_state| {
1014                project_state.context.update(cx, |context, cx| {
1015                    context
1016                        .related_files_with_buffers(cx)
1017                        .map(|(mut related_file, buffer)| {
1018                            related_file.in_open_source_repo = buffer
1019                                .read(cx)
1020                                .file()
1021                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
1022                            related_file
1023                        })
1024                        .collect()
1025                })
1026            })
1027            .unwrap_or_default()
1028    }
1029
1030    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1031        self.projects
1032            .get(&project.entity_id())
1033            .and_then(|project| project.copilot.clone())
1034    }
1035
1036    pub fn start_copilot_for_project(
1037        &mut self,
1038        project: &Entity<Project>,
1039        cx: &mut Context<Self>,
1040    ) -> Option<Entity<Copilot>> {
1041        if DisableAiSettings::get(None, cx).disable_ai {
1042            return None;
1043        }
1044        let state = self.get_or_init_project(project, cx);
1045
1046        if state.copilot.is_some() {
1047            return state.copilot.clone();
1048        }
1049        let _project = project.clone();
1050        let project = project.read(cx);
1051
1052        let node = project.node_runtime().cloned();
1053        if let Some(node) = node {
1054            let next_id = project.languages().next_language_server_id();
1055            let fs = project.fs().clone();
1056
1057            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1058            state.copilot = Some(copilot.clone());
1059            Some(copilot)
1060        } else {
1061            None
1062        }
1063    }
1064
1065    pub fn context_for_project_with_buffers<'a>(
1066        &'a self,
1067        project: &Entity<Project>,
1068        cx: &'a mut App,
1069    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1070        self.projects
1071            .get(&project.entity_id())
1072            .map(|project| {
1073                project.context.update(cx, |context, cx| {
1074                    context.related_files_with_buffers(cx).collect()
1075                })
1076            })
1077            .unwrap_or_default()
1078    }
1079
1080    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1081        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1082            self.user_store.read(cx).edit_prediction_usage()
1083        } else {
1084            None
1085        }
1086    }
1087
1088    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1089        self.get_or_init_project(project, cx);
1090    }
1091
1092    pub fn register_buffer(
1093        &mut self,
1094        buffer: &Entity<Buffer>,
1095        project: &Entity<Project>,
1096        cx: &mut Context<Self>,
1097    ) {
1098        let project_state = self.get_or_init_project(project, cx);
1099        Self::register_buffer_impl(project_state, buffer, project, cx);
1100    }
1101
1102    fn get_or_init_project(
1103        &mut self,
1104        project: &Entity<Project>,
1105        cx: &mut Context<Self>,
1106    ) -> &mut ProjectState {
1107        let entity_id = project.entity_id();
1108        self.projects
1109            .entry(entity_id)
1110            .or_insert_with(|| ProjectState {
1111                context: {
1112                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
1113                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
1114                        this.handle_excerpt_store_event(entity_id, event);
1115                    })
1116                    .detach();
1117                    related_excerpt_store
1118                },
1119                events: VecDeque::new(),
1120                last_event: None,
1121                recent_paths: VecDeque::new(),
1122                debug_tx: None,
1123                registered_buffers: HashMap::default(),
1124                current_prediction: None,
1125                cancelled_predictions: HashSet::default(),
1126                pending_predictions: ArrayVec::new(),
1127                next_pending_prediction_id: 0,
1128                last_edit_prediction_refresh: None,
1129                last_jump_prediction_refresh: None,
1130                license_detection_watchers: HashMap::default(),
1131                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
1132                _subscriptions: [
1133                    cx.subscribe(&project, Self::handle_project_event),
1134                    cx.observe_release(&project, move |this, _, cx| {
1135                        this.projects.remove(&entity_id);
1136                        cx.notify();
1137                    }),
1138                ],
1139                copilot: None,
1140            })
1141    }
1142
1143    pub fn remove_project(&mut self, project: &Entity<Project>) {
1144        self.projects.remove(&project.entity_id());
1145    }
1146
1147    fn handle_excerpt_store_event(
1148        &mut self,
1149        project_entity_id: EntityId,
1150        event: &RelatedExcerptStoreEvent,
1151    ) {
1152        if let Some(project_state) = self.projects.get(&project_entity_id) {
1153            if let Some(debug_tx) = project_state.debug_tx.clone() {
1154                match event {
1155                    RelatedExcerptStoreEvent::StartedRefresh => {
1156                        debug_tx
1157                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
1158                                ContextRetrievalStartedDebugEvent {
1159                                    project_entity_id: project_entity_id,
1160                                    timestamp: Instant::now(),
1161                                    search_prompt: String::new(),
1162                                },
1163                            ))
1164                            .ok();
1165                    }
1166                    RelatedExcerptStoreEvent::FinishedRefresh {
1167                        cache_hit_count,
1168                        cache_miss_count,
1169                        mean_definition_latency,
1170                        max_definition_latency,
1171                    } => {
1172                        debug_tx
1173                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
1174                                ContextRetrievalFinishedDebugEvent {
1175                                    project_entity_id: project_entity_id,
1176                                    timestamp: Instant::now(),
1177                                    metadata: vec![
1178                                        (
1179                                            "Cache Hits",
1180                                            format!(
1181                                                "{}/{}",
1182                                                cache_hit_count,
1183                                                cache_hit_count + cache_miss_count
1184                                            )
1185                                            .into(),
1186                                        ),
1187                                        (
1188                                            "Max LSP Time",
1189                                            format!("{} ms", max_definition_latency.as_millis())
1190                                                .into(),
1191                                        ),
1192                                        (
1193                                            "Mean LSP Time",
1194                                            format!("{} ms", mean_definition_latency.as_millis())
1195                                                .into(),
1196                                        ),
1197                                    ],
1198                                },
1199                            ))
1200                            .ok();
1201                    }
1202                }
1203            }
1204        }
1205    }
1206
1207    pub fn debug_info(
1208        &mut self,
1209        project: &Entity<Project>,
1210        cx: &mut Context<Self>,
1211    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1212        let project_state = self.get_or_init_project(project, cx);
1213        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1214        project_state.debug_tx = Some(debug_watch_tx);
1215        debug_watch_rx
1216    }
1217
1218    fn handle_project_event(
1219        &mut self,
1220        project: Entity<Project>,
1221        event: &project::Event,
1222        cx: &mut Context<Self>,
1223    ) {
1224        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1225            return;
1226        }
1227        // TODO [zeta2] init with recent paths
1228        match event {
1229            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1230                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1231                    return;
1232                };
1233                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1234                if let Some(path) = path {
1235                    if let Some(ix) = project_state
1236                        .recent_paths
1237                        .iter()
1238                        .position(|probe| probe == &path)
1239                    {
1240                        project_state.recent_paths.remove(ix);
1241                    }
1242                    project_state.recent_paths.push_front(path);
1243                }
1244            }
1245            project::Event::DiagnosticsUpdated { .. } => {
1246                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1247                    self.refresh_prediction_from_diagnostics(
1248                        project,
1249                        DiagnosticSearchScope::Global,
1250                        cx,
1251                    );
1252                }
1253            }
1254            _ => (),
1255        }
1256    }
1257
    /// Ensures `buffer` is registered in `project_state` and returns its entry.
    ///
    /// If the buffer has a file whose worktree is found in `project`, a
    /// `LicenseDetectionWatcher` is created for that worktree the first time the
    /// worktree is seen. The watcher is removed again when the worktree entity
    /// is released.
    ///
    /// The first time a buffer is registered, its current text snapshot and file
    /// are captured together with two subscriptions:
    /// - buffer `Edited` events are forwarded to `report_changes_for_buffer`
    ///   (only while the project handle can still be upgraded);
    /// - releasing the buffer removes its entry from the project state.
    ///
    /// Later calls return the existing entry without re-capturing anything.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // One license watcher per worktree that contains a registered buffer.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher once the worktree is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Only a weak project handle is captured; edits are
                            // ignored once it can no longer be upgraded.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        // Unregister the buffer when its entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1326
    /// Records the changes made to `buffer` since it was last reported.
    ///
    /// Compares the buffer's current snapshot with the one stored in its
    /// registration. When the version changed and there are edits, this:
    /// - stamps `last_edit_at` on pending predictions whose editable range
    ///   overlaps the edited range;
    /// - for local edits, records a `UserActionRecord` classified from the
    ///   inserted/deleted lengths and the number of edits;
    /// - for local edits, and for collaborator edits that overlap the locality
    ///   region, either coalesces the change into the in-progress `last_event`
    ///   or finalizes that event into `events` and starts a new one.
    ///
    /// `is_predicted` marks changes that came from accepting a prediction; a
    /// change in this flag relative to the last event always starts a new event.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the last report for this buffer.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        // Summarize every edit since the previous snapshot: counts, the anchor
        // range spanning them, and the end offset of the last edit.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            // NOTE(review): spans from the first edit's start to the latest
            // edit's end, which assumes edits arrive in ascending order.
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Remote edits only enter the history when they overlap the locality region.
        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if is_local {
            // An edit counts as a "char" action when the total inserted
            // (or deleted) length equals the number of edits, e.g. one
            // character per cursor. Anything larger counts as a selection action.
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        if !include_in_history {
            return;
        }

        let events = &mut project_state.events;

        if let Some(last_event) = project_state.last_event.as_mut() {
            // Coalesce into the in-progress event only when this change directly
            // continues the same buffer from the event's latest snapshot, has the
            // same predicted/manual source, and lies within
            // CHANGE_GROUPING_LINE_SPAN lines of the previous edit.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // After a pause of at least LAST_CHANGE_GROUPING_TIME, remember the
                // event's state at this pause boundary before extending it.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Finalize the previous event into history. Evicting the oldest entry
        // first keeps `events` at no more than EVENT_COUNT_MAX - 1 entries.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        // Start a new in-progress event for this change.
        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1477
1478    fn prediction_at(
1479        &mut self,
1480        buffer: &Entity<Buffer>,
1481        position: Option<language::Anchor>,
1482        project: &Entity<Project>,
1483        cx: &App,
1484    ) -> Option<BufferEditPrediction<'_>> {
1485        let project_state = self.projects.get_mut(&project.entity_id())?;
1486        if let Some(position) = position
1487            && let Some(buffer) = project_state
1488                .registered_buffers
1489                .get_mut(&buffer.entity_id())
1490        {
1491            buffer.last_position = Some(position);
1492        }
1493
1494        let CurrentEditPrediction {
1495            requested_by,
1496            prediction,
1497            ..
1498        } = project_state.current_prediction.as_ref()?;
1499
1500        if prediction.targets_buffer(buffer.read(cx)) {
1501            Some(BufferEditPrediction::Local { prediction })
1502        } else {
1503            let show_jump = match requested_by {
1504                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1505                    requested_by_buffer_id == &buffer.entity_id()
1506                }
1507                PredictionRequestedBy::DiagnosticsUpdate => true,
1508            };
1509
1510            if show_jump {
1511                Some(BufferEditPrediction::Jump { prediction })
1512            } else {
1513                None
1514            }
1515        }
1516    }
1517
1518    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1519        let Some(current_prediction) = self
1520            .projects
1521            .get_mut(&project.entity_id())
1522            .and_then(|project_state| project_state.current_prediction.take())
1523        else {
1524            return;
1525        };
1526
1527        self.report_changes_for_buffer(
1528            &current_prediction.prediction.buffer,
1529            project,
1530            true,
1531            true,
1532            cx,
1533        );
1534
1535        // can't hold &mut project_state ref across report_changes_for_buffer_call
1536        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1537            return;
1538        };
1539
1540        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1541            project_state.cancel_pending_prediction(pending_prediction, cx);
1542        }
1543
1544        match self.edit_prediction_model {
1545            EditPredictionModel::Sweep => {
1546                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1547            }
1548            EditPredictionModel::Mercury => {
1549                mercury::edit_prediction_accepted(
1550                    current_prediction.prediction.id,
1551                    self.client.http_client(),
1552                    cx,
1553                );
1554            }
1555            EditPredictionModel::Zeta => {
1556                let is_cloud = !matches!(
1557                    all_language_settings(None, cx).edit_predictions.provider,
1558                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1559                );
1560                if is_cloud {
1561                    zeta::edit_prediction_accepted(self, current_prediction, cx)
1562                }
1563            }
1564            EditPredictionModel::Fim { .. } => {}
1565        }
1566    }
1567
    /// Background loop that batches rejected predictions and reports them to
    /// `/predict_edits/reject`.
    ///
    /// Each received rejection is appended to `batched`. While the batch holds
    /// fewer than half of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` entries,
    /// the loop keeps collecting as long as another rejection arrives within
    /// `REJECT_REQUEST_DEBOUNCE`. Once that stops, it sends the newest rejections,
    /// up to the per-request maximum. They are removed from `batched` only if the
    /// request succeeds, so rejections from a failed request are retried later.
    ///
    /// The loop ends once the channel is closed and drained.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            // Debounce small batches: if another rejection arrives before the
            // timer fires, loop to collect it; otherwise (timer fired or the
            // channel closed) fall through and flush.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            // NOTE(review): panics if the LLM URL cannot be built.
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            // NOTE(review): the whole batch is sent under the organization id of
            // the most recently received payload, even if earlier rejections
            // arrived with a different one.
            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Drop the sent tail only on success; failures stay queued for retry.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1631
    /// Background loop that emits `EDIT_PREDICTION_SETTLED_EVENT` telemetry once
    /// predictions have stopped being edited.
    ///
    /// `rx` delivers the enqueue time of each new pending settled prediction. The loop
    /// sleeps until a prediction could first have been quiet for
    /// `EDIT_PREDICTION_SETTLED_QUIESCENCE`. It then sweeps every registered buffer of
    /// every project:
    /// - entries older than `EDIT_PREDICTION_SETTLED_TTL` are dropped without an event;
    /// - entries with no edit for `EDIT_PREDICTION_SETTLED_QUIESCENCE` report the current
    ///   text of their editable range and are removed;
    /// - the remaining entries are kept, and the oldest `last_edit_at` among them
    ///   schedules the next sweep.
    ///
    /// Exits when `rx` closes or the store entity can no longer be upgraded.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                // A sweep is scheduled: wait until it is due, then fall through to sweep.
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                // Nothing scheduled: block until a new prediction is enqueued.
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                // Discard any other queued enqueue times. The first one sets the earliest
                // wake, and the sweep re-derives later deadlines from buffer state.
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();

            // Earliest `last_edit_at` among predictions that are still pending after
            // this sweep; it determines the next wake time.
            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                // Expired: drop without reporting.
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                // Quiet long enough: report the settled region and drop.
                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                        example = pending_prediction.example.take(),
                                        e2e_latency = pending_prediction.e2e_latency.as_millis(),
                                    );

                                    return false;
                                }

                                // Still being edited: keep, and track the earliest deadline.
                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            // `None` (nothing left pending) makes the next iteration wait on `rx`.
            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1717
1718    pub(crate) fn enqueue_settled_prediction(
1719        &mut self,
1720        request_id: EditPredictionId,
1721        project: &Entity<Project>,
1722        edited_buffer: &Entity<Buffer>,
1723        edited_buffer_snapshot: &BufferSnapshot,
1724        editable_offset_range: Range<usize>,
1725        example: Option<ExampleSpec>,
1726        e2e_latency: std::time::Duration,
1727        cx: &mut Context<Self>,
1728    ) {
1729        let this = &mut *self;
1730        let project_state = this.get_or_init_project(project, cx);
1731        if let Some(buffer) = project_state
1732            .registered_buffers
1733            .get_mut(&edited_buffer.entity_id())
1734        {
1735            let now = cx.background_executor().now();
1736            buffer.pending_predictions.push(PendingSettledPrediction {
1737                request_id: request_id,
1738                editable_anchor_range: edited_buffer_snapshot
1739                    .anchor_range_around(editable_offset_range),
1740                example,
1741                e2e_latency,
1742                enqueued_at: now,
1743                last_edit_at: now,
1744            });
1745            this.settled_predictions_tx.unbounded_send(now).ok();
1746        }
1747    }
1748
1749    fn reject_current_prediction(
1750        &mut self,
1751        reason: EditPredictionRejectReason,
1752        project: &Entity<Project>,
1753        cx: &App,
1754    ) {
1755        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1756            project_state.pending_predictions.clear();
1757            if let Some(prediction) = project_state.current_prediction.take() {
1758                let model_version = prediction.prediction.model_version.clone();
1759                self.reject_prediction(
1760                    prediction.prediction.id,
1761                    reason,
1762                    prediction.was_shown,
1763                    model_version,
1764                    Some(prediction.e2e_latency),
1765                    cx,
1766                );
1767            }
1768        };
1769    }
1770
    /// Records that the project's current prediction was displayed as `display_type`.
    ///
    /// - `shown_with` is updated on every non-jump display. On a jump display it is
    ///   updated only if nothing was recorded yet.
    /// - The first non-jump display sets `was_shown` and pushes the prediction onto
    ///   `shown_predictions`. That list is capped at 50 entries; the evicted entry's
    ///   rating is removed from `rated_predictions`.
    /// - For the Sweep model, `sweep_ai::edit_prediction_shown` is called whenever
    ///   `display_type` differs from the previously recorded `shown_with`.
    fn did_show_current_prediction(
        &mut self,
        project: &Entity<Project>,
        display_type: edit_prediction_types::SuggestionDisplayType,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
            return;
        };

        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
        let previous_shown_with = current_prediction.shown_with;

        // A jump display never overwrites an already recorded display type.
        if previous_shown_with.is_none() || !is_jump {
            current_prediction.shown_with = Some(display_type);
        }

        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;

        if is_first_non_jump_show {
            current_prediction.was_shown = true;
        }

        // NOTE(review): this compares against the *previous* value, not the updated
        // `shown_with`. After a non-jump display, every later Jump display still counts
        // as "changed" (shown_with stays non-jump) and re-notifies Sweep each time —
        // confirm this is intended.
        let display_type_changed = previous_shown_with != Some(display_type);

        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
            sweep_ai::edit_prediction_shown(
                &self.sweep_ai,
                self.client.clone(),
                &current_prediction.prediction,
                display_type,
                cx,
            );
        }

        if is_first_non_jump_show {
            self.shown_predictions
                .push_front(current_prediction.prediction.clone());
            // Keep at most 50 shown predictions; forget ratings for evicted ones.
            if self.shown_predictions.len() > 50 {
                let completion = self.shown_predictions.pop_back().unwrap();
                self.rated_predictions.remove(&completion.id);
            }
        }
    }
1819
1820    fn reject_prediction(
1821        &mut self,
1822        prediction_id: EditPredictionId,
1823        reason: EditPredictionRejectReason,
1824        was_shown: bool,
1825        model_version: Option<String>,
1826        e2e_latency: Option<std::time::Duration>,
1827        cx: &App,
1828    ) {
1829        match self.edit_prediction_model {
1830            EditPredictionModel::Zeta => {
1831                let is_cloud = !matches!(
1832                    all_language_settings(None, cx).edit_predictions.provider,
1833                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1834                );
1835
1836                if is_cloud {
1837                    let organization_id = self
1838                        .user_store
1839                        .read(cx)
1840                        .current_organization()
1841                        .map(|organization| organization.id.clone());
1842
1843                    self.reject_predictions_tx
1844                        .unbounded_send(EditPredictionRejectionPayload {
1845                            rejection: EditPredictionRejection {
1846                                request_id: prediction_id.to_string(),
1847                                reason,
1848                                was_shown,
1849                                model_version,
1850                                e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
1851                            },
1852                            organization_id,
1853                        })
1854                        .log_err();
1855                }
1856            }
1857            EditPredictionModel::Mercury => {
1858                mercury::edit_prediction_rejected(
1859                    prediction_id,
1860                    was_shown,
1861                    reason,
1862                    self.client.http_client(),
1863                    cx,
1864                );
1865            }
1866            EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1867        }
1868    }
1869
1870    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1871        self.projects
1872            .get(&project.entity_id())
1873            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1874    }
1875
1876    pub fn refresh_prediction_from_buffer(
1877        &mut self,
1878        project: Entity<Project>,
1879        buffer: Entity<Buffer>,
1880        position: language::Anchor,
1881        cx: &mut Context<Self>,
1882    ) {
1883        self.queue_prediction_refresh(
1884            project.clone(),
1885            PredictEditsRequestTrigger::Other,
1886            buffer.entity_id(),
1887            cx,
1888            move |this, cx| {
1889                let Some(request_task) = this
1890                    .update(cx, |this, cx| {
1891                        this.request_prediction(
1892                            &project,
1893                            &buffer,
1894                            position,
1895                            PredictEditsRequestTrigger::Other,
1896                            cx,
1897                        )
1898                    })
1899                    .log_err()
1900                else {
1901                    return Task::ready(anyhow::Ok(None));
1902                };
1903
1904                cx.spawn(async move |_cx| {
1905                    request_task.await.map(|prediction_result| {
1906                        prediction_result.map(|prediction_result| {
1907                            (
1908                                prediction_result,
1909                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1910                            )
1911                        })
1912                    })
1913                })
1914            },
1915        )
1916    }
1917
    /// Queues a prediction request at the next diagnostic location, searched relative to
    /// the cursor in the project's active buffer.
    ///
    /// Does nothing when any of these hold:
    /// - the configured provider is not handled by this store;
    /// - a workspace on this project is following a leader in its active pane;
    /// - the project is not registered;
    /// - a current prediction already exists.
    ///
    /// `scope` controls the search range. `Local` searches `DIAGNOSTIC_LINES_RANGE` rows
    /// around the cursor. `Global` passes a default range — NOTE(review): presumably
    /// `next_diagnostic_location` treats that as "anywhere"; confirm.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        if currently_following(&project, cx) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Resolve the active buffer and cursor. Bail out if predictions are
                // disabled there.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // A current prediction may have appeared while this request was in
                    // flight. If so, report this result as `CurrentPreferred` instead of
                    // offering it.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                    e2e_latency: prediction_result.e2e_latency,
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
2034
2035    fn predictions_enabled_at(
2036        snapshot: &BufferSnapshot,
2037        position: Option<language::Anchor>,
2038        cx: &App,
2039    ) -> bool {
2040        let file = snapshot.file();
2041        let all_settings = all_language_settings(file, cx);
2042        if !all_settings.show_edit_predictions(snapshot.language(), cx)
2043            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
2044        {
2045            return false;
2046        }
2047
2048        if let Some(last_position) = position {
2049            let settings = snapshot.settings_at(last_position, cx);
2050
2051            if !settings.edit_predictions_disabled_in.is_empty()
2052                && let Some(scope) = snapshot.language_scope_at(last_position)
2053                && let Some(scope_name) = scope.override_name()
2054                && settings
2055                    .edit_predictions_disabled_in
2056                    .iter()
2057                    .any(|s| s == scope_name)
2058            {
2059                return false;
2060            }
2061        }
2062
2063        true
2064    }
2065
    /// Minimum spacing between consecutive prediction refreshes for the same entity and
    /// trigger kind (see `queue_prediction_refresh`).
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2067}
2068
2069fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2070    let Some(app_state) = AppState::try_global(cx).and_then(|app_state| app_state.upgrade()) else {
2071        return false;
2072    };
2073
2074    app_state
2075        .workspace_store
2076        .read(cx)
2077        .workspaces()
2078        .filter_map(|workspace| workspace.upgrade())
2079        .any(|workspace| {
2080            workspace.read(cx).project().entity_id() == project.entity_id()
2081                && workspace
2082                    .read(cx)
2083                    .leader_for_pane(workspace.read(cx).active_pane())
2084                    .is_some()
2085        })
2086}
2087
2088fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2089    match provider {
2090        EditPredictionProvider::Zed
2091        | EditPredictionProvider::Sweep
2092        | EditPredictionProvider::Mercury
2093        | EditPredictionProvider::Ollama
2094        | EditPredictionProvider::OpenAiCompatibleApi
2095        | EditPredictionProvider::Experimental(_) => true,
2096        EditPredictionProvider::None
2097        | EditPredictionProvider::Copilot
2098        | EditPredictionProvider::Codestral => false,
2099    }
2100}
2101
2102impl EditPredictionStore {
    /// Schedules a throttled prediction refresh for `project` and tracks it as a pending
    /// prediction.
    ///
    /// - `request_trigger` selects the throttle slot. Diagnostics use
    ///   `last_jump_prediction_refresh`; every other trigger uses
    ///   `last_edit_prediction_refresh`.
    /// - `throttle_entity` identifies what triggered the refresh. The task waits out
    ///   `THROTTLE_TIMEOUT` only when the last refresh in the same slot was for the same
    ///   entity.
    /// - `do_refresh` performs the actual request once throttling allows it.
    ///
    /// A successful result becomes the project's current prediction unless the existing
    /// one is preferred. The losing prediction, or a result carrying a rejection reason,
    /// is reported via `reject_prediction` / `reject_current_prediction`. When a task
    /// completes, pending entries queued before it are cancelled. The queue holds at
    /// most `max_pending_predictions` entries (per provider); when it is full, the most
    /// recently queued entry is replaced and cancelled.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        request_trigger: PredictEditsRequestTrigger,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        // Diagnostics-triggered (jump) refreshes are throttled separately from all
        // other triggers.
        fn select_throttle(
            project_state: &mut ProjectState,
            request_trigger: PredictEditsRequestTrigger,
        ) -> &mut Option<(EntityId, Instant)> {
            match request_trigger {
                PredictEditsRequestTrigger::Diagnostics => {
                    &mut project_state.last_jump_prediction_refresh
                }
                _ => &mut project_state.last_edit_prediction_refresh,
            }
        }

        // Providers that need acceptance tracking keep cancelled tasks alive
        // (`drop_on_cancel == false`). `max_pending_predictions` caps the queue length.
        let (needs_acceptance_tracking, max_pending_predictions) =
            match all_language_settings(None, cx).edit_predictions.provider {
                EditPredictionProvider::Zed
                | EditPredictionProvider::Sweep
                | EditPredictionProvider::Mercury
                | EditPredictionProvider::Experimental(_) => (true, 2),
                EditPredictionProvider::Ollama => (false, 1),
                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
                EditPredictionProvider::None
                | EditPredictionProvider::Copilot
                | EditPredictionProvider::Codestral => {
                    log::error!("queue_prediction_refresh called with non-store provider");
                    return;
                }
            };

        let drop_on_cancel = !needs_acceptance_tracking;
        let throttle_timeout = Self::THROTTLE_TIMEOUT;
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        // Snapshot of the throttle slot. It is compared later to detect whether another
        // request proceeded while this one was waiting.
        let throttle_at_enqueue = *select_throttle(project_state, request_trigger);

        let task = cx.spawn(async move |this, cx| {
            // Remaining throttle window, if the previous refresh in this slot was for the
            // same entity and happened less than `throttle_timeout` ago.
            let throttle_wait = this
                .update(cx, |this, cx| {
                    let project_state = this.get_or_init_project(&project, cx);
                    let throttle = *select_throttle(project_state, request_trigger);

                    throttle.and_then(|(last_entity, last_timestamp)| {
                        if throttle_entity != last_entity {
                            return None;
                        }
                        (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
                    })
                })
                .ok()
                .flatten();

            if let Some(timeout) = throttle_wait {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request. Also skip if another task already
            // proceeded since we were enqueued (duplicate).
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                let was_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);
                if was_cancelled {
                    return;
                }

                // Another request has been already sent since this was enqueued
                if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
                    return;
                }

                let new_refresh = (throttle_entity, Instant::now());
                *select_throttle(project_state, request_trigger) = Some(new_refresh);
                is_cancelled = false;
            })
            .ok();
            // NOTE(review): this early return skips the `pending_predictions` cleanup
            // below. In the "duplicate" case this task's entry stays in the list until a
            // later completion drains it. Confirm `is_refreshing` cannot be left stale.
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                // Cancellation may also have happened while `do_refresh` was running.
                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                                shown_with: None,
                                e2e_latency: prediction_result.e2e_latency,
                            };

                            // Keep whichever prediction `should_replace_prediction` prefers
                            // and report the other one as rejected.
                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                        cx,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                        new_prediction.prediction.model_version,
                                        Some(new_prediction.e2e_latency),
                                        cx,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(
                                prediction_result.id,
                                reject_reason,
                                false,
                                None,
                                Some(prediction_result.e2e_latency),
                                cx,
                            );
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-fetch: the reject calls above borrowed `this` mutably.
                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Remove this task's entry and cancel every entry queued before it.
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        if project_state.pending_predictions.len() < max_pending_predictions {
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
        } else {
            // Queue full: replace the most recently queued entry with this one and
            // cancel the replaced entry.
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
2308
2309    pub fn request_prediction(
2310        &mut self,
2311        project: &Entity<Project>,
2312        active_buffer: &Entity<Buffer>,
2313        position: language::Anchor,
2314        trigger: PredictEditsRequestTrigger,
2315        cx: &mut Context<Self>,
2316    ) -> Task<Result<Option<EditPredictionResult>>> {
2317        self.request_prediction_internal(
2318            project.clone(),
2319            active_buffer.clone(),
2320            position,
2321            trigger,
2322            cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2323            cx,
2324        )
2325    }
2326
2327    fn request_prediction_internal(
2328        &mut self,
2329        project: Entity<Project>,
2330        active_buffer: Entity<Buffer>,
2331        position: language::Anchor,
2332        trigger: PredictEditsRequestTrigger,
2333        allow_jump: bool,
2334        cx: &mut Context<Self>,
2335    ) -> Task<Result<Option<EditPredictionResult>>> {
2336        self.get_or_init_project(&project, cx);
2337        let project_state = self.projects.get(&project.entity_id()).unwrap();
2338        let stored_events = project_state.events(cx);
2339        let has_events = !stored_events.is_empty();
2340        let events: Vec<Arc<zeta_prompt::Event>> =
2341            stored_events.iter().map(|e| e.event.clone()).collect();
2342        let debug_tx = project_state.debug_tx.clone();
2343
2344        let snapshot = active_buffer.read(cx).snapshot();
2345        let cursor_point = position.to_point(&snapshot);
2346        let current_offset = position.to_offset(&snapshot);
2347
2348        let mut user_actions: Vec<UserActionRecord> =
2349            project_state.user_actions.iter().cloned().collect();
2350
2351        if let Some(last_action) = user_actions.last() {
2352            if last_action.buffer_id == active_buffer.entity_id()
2353                && current_offset != last_action.offset
2354            {
2355                let timestamp_epoch_ms = SystemTime::now()
2356                    .duration_since(UNIX_EPOCH)
2357                    .map(|d| d.as_millis() as u64)
2358                    .unwrap_or(0);
2359                user_actions.push(UserActionRecord {
2360                    action_type: UserActionType::CursorMovement,
2361                    buffer_id: active_buffer.entity_id(),
2362                    line_number: cursor_point.row,
2363                    offset: current_offset,
2364                    timestamp_epoch_ms,
2365                });
2366            }
2367        }
2368        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2369        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2370        let diagnostic_search_range =
2371            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2372
2373        let related_files = self.context_for_project(&project, cx);
2374
2375        let is_open_source = snapshot
2376            .file()
2377            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2378            && events.iter().all(|event| event.in_open_source_repo())
2379            && related_files.iter().all(|file| file.in_open_source_repo);
2380
2381        let can_collect_data = !cfg!(test)
2382            && is_open_source
2383            && self.is_data_collection_enabled(cx)
2384            && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2385
2386        let recent_paths = project_state.recent_paths.clone();
2387
2388        let inputs = EditPredictionModelInput {
2389            project: project.clone(),
2390            buffer: active_buffer,
2391            snapshot,
2392            position,
2393            events,
2394            related_files,
2395            recent_paths,
2396            trigger,
2397            diagnostic_search_range: diagnostic_search_range,
2398            debug_tx,
2399            user_actions,
2400            can_collect_data,
2401            is_open_source,
2402        };
2403
2404        let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2405
2406        let task = match self.edit_prediction_model {
2407            EditPredictionModel::Zeta => {
2408                zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2409            }
2410            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2411            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2412            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2413        };
2414
2415        cx.spawn(async move |this, cx| {
2416            let prediction = task.await?;
2417
2418            // Only fall back to diagnostics-based prediction if we got a
2419            // the model had nothing to suggest for the buffer
2420            if prediction.is_none()
2421                && allow_jump
2422                && has_events
2423                && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2424            {
2425                this.update(cx, |this, cx| {
2426                    this.refresh_prediction_from_diagnostics(
2427                        project,
2428                        DiagnosticSearchScope::Local,
2429                        cx,
2430                    );
2431                })?;
2432                return anyhow::Ok(None);
2433            }
2434
2435            Ok(prediction)
2436        })
2437    }
2438
    /// Finds a diagnostic location that an edit prediction could jump to.
    ///
    /// First pass (active buffer):
    /// - Pick the diagnostic group whose primary entry starts on the row closest to
    ///   `active_buffer_cursor_point`.
    /// - Skip groups that overlap `active_buffer_diagnostic_search_range`.
    /// - Skip groups that start within `DIAGNOSTIC_LINES_RANGE` rows of a selection
    ///   returned by `selections_in_range`, unless they are also within that many rows
    ///   of the local cursor.
    ///
    /// Second pass (only if the first found nothing):
    /// - Consider other project paths that have diagnostic summaries.
    /// - Order them by how many leading path components they share with the active
    ///   buffer's path, most first.
    /// - Open each candidate buffer in that order.
    /// - Skip any buffer whose snapshot reports selections from `selections_in_range`.
    /// - Return the start of the minimum-severity diagnostic from the first buffer
    ///   that has any diagnostics.
    ///
    /// Returns `Ok(None)` when no location is found. Errors from opening a candidate
    /// buffer are propagated.
    pub(crate) async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Rows of every selection head reported by the snapshot. These are presumably
        // collaborators' (remote) selections, per the variable name — TODO confirm
        // against `selections_in_range`.
        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
            .flat_map(|(_, _, _, selections)| {
                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
            })
            .collect();

        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Diagnostics inside the search window around the cursor are skipped
                // (presumably already covered by the regular prediction request).
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    return None;
                }
                // Leave diagnostics near another selection alone, unless the local
                // cursor is also nearby.
                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
                });
                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
                    <= DIAGNOSTIC_LINES_RANGE;
                if near_collaborator && !near_local {
                    return None;
                }
                Some(range.start)
            })
            // Closest by row distance to the cursor.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // `None` when the active buffer has no backing file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // Pair each other path that has diagnostics with the number of leading
            // path components it shares with the active buffer's path (worktree ids
            // are not compared here).
            // NOTE(review): if `diagnostic_summaries` yields one entry per language
            // server, the same path can appear more than once and be opened twice
            // below — confirm whether deduplication is wanted.
            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
                project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .map(|(path, _, _)| {
                        let shared_prefix = path
                            .path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count();
                        (path, shared_prefix)
                    })
                    .collect()
            });

            // Longest shared prefix first; `sort_by` is stable, so ties keep the
            // iteration order of `diagnostic_summaries`.
            candidates.sort_by(|a, b| b.1.cmp(&a.1));

            for (path, _) in candidates {
                let candidate_buffer = project
                    .update(cx, |project, cx| project.open_buffer(path, cx))
                    .await?;

                let (has_collaborators, diagnostic_position) =
                    candidate_buffer.read_with(cx, |buffer, _cx| {
                        let snapshot = buffer.snapshot();
                        let has_collaborators = snapshot
                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
                            .next()
                            .is_some();
                        // Lowest `severity` value wins; with LSP severities (ERROR = 1)
                        // that is presumably the most severe diagnostic — TODO confirm.
                        let position = buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start);
                        (has_collaborators, position)
                    });

                // Do not jump into a buffer where anyone else has selections.
                if has_collaborators {
                    continue;
                }

                if let Some(position) = diagnostic_position {
                    jump_location = Some((candidate_buffer, position));
                    break;
                }
            }
        }

        anyhow::Ok(jump_location)
    }
2548
2549    async fn send_raw_llm_request(
2550        request: RawCompletionRequest,
2551        client: Arc<Client>,
2552        custom_url: Option<Arc<Url>>,
2553        llm_token: LlmApiToken,
2554        organization_id: Option<OrganizationId>,
2555        app_version: Version,
2556    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2557        let url = if let Some(custom_url) = custom_url {
2558            custom_url.as_ref().clone()
2559        } else {
2560            client
2561                .http_client()
2562                .build_zed_llm_url("/predict_edits/raw", &[])?
2563        };
2564
2565        Self::send_api_request(
2566            |builder| {
2567                let req = builder
2568                    .uri(url.as_ref())
2569                    .body(serde_json::to_string(&request)?.into());
2570                Ok(req?)
2571            },
2572            client,
2573            llm_token,
2574            organization_id,
2575            app_version,
2576            true,
2577        )
2578        .await
2579    }
2580
2581    pub(crate) async fn send_v3_request(
2582        input: ZetaPromptInput,
2583        client: Arc<Client>,
2584        llm_token: LlmApiToken,
2585        organization_id: Option<OrganizationId>,
2586        app_version: Version,
2587        trigger: PredictEditsRequestTrigger,
2588    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2589        let url = client
2590            .http_client()
2591            .build_zed_llm_url("/predict_edits/v3", &[])?;
2592
2593        let request = PredictEditsV3Request { input, trigger };
2594
2595        let json_bytes = serde_json::to_vec(&request)?;
2596        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2597
2598        Self::send_api_request(
2599            |builder| {
2600                let req = builder
2601                    .uri(url.as_ref())
2602                    .header("Content-Encoding", "zstd")
2603                    .body(compressed.clone().into());
2604                Ok(req?)
2605            },
2606            client,
2607            llm_token,
2608            organization_id,
2609            app_version,
2610            true,
2611        )
2612        .await
2613    }
2614
2615    async fn send_api_request<Res>(
2616        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2617        client: Arc<Client>,
2618        llm_token: LlmApiToken,
2619        organization_id: Option<OrganizationId>,
2620        app_version: Version,
2621        require_auth: bool,
2622    ) -> Result<(Res, Option<EditPredictionUsage>)>
2623    where
2624        Res: DeserializeOwned,
2625    {
2626        let http_client = client.http_client();
2627
2628        let mut token = if require_auth {
2629            Some(llm_token.acquire(&client, organization_id.clone()).await?)
2630        } else {
2631            llm_token
2632                .acquire(&client, organization_id.clone())
2633                .await
2634                .ok()
2635        };
2636        let mut did_retry = false;
2637
2638        loop {
2639            let request_builder = http_client::Request::builder().method(Method::POST);
2640
2641            let mut request_builder = request_builder
2642                .header("Content-Type", "application/json")
2643                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2644
2645            // Only add Authorization header if we have a token
2646            if let Some(ref token_value) = token {
2647                request_builder =
2648                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2649            }
2650
2651            let request = build(request_builder)?;
2652
2653            let mut response = http_client.send(request).await?;
2654
2655            if let Some(minimum_required_version) = response
2656                .headers()
2657                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2658                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2659            {
2660                anyhow::ensure!(
2661                    app_version >= minimum_required_version,
2662                    ZedUpdateRequiredError {
2663                        minimum_version: minimum_required_version
2664                    }
2665                );
2666            }
2667
2668            if response.status().is_success() {
2669                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2670
2671                let mut body = Vec::new();
2672                response.body_mut().read_to_end(&mut body).await?;
2673                return Ok((serde_json::from_slice(&body)?, usage));
2674            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2675                did_retry = true;
2676                token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
2677            } else {
2678                let mut body = String::new();
2679                response.body_mut().read_to_string(&mut body).await?;
2680                anyhow::bail!(
2681                    "Request failed with status: {:?}\nBody: {}",
2682                    response.status(),
2683                    body
2684                );
2685            }
2686        }
2687    }
2688
2689    pub fn refresh_context(
2690        &mut self,
2691        project: &Entity<Project>,
2692        buffer: &Entity<language::Buffer>,
2693        cursor_position: language::Anchor,
2694        cx: &mut Context<Self>,
2695    ) {
2696        self.get_or_init_project(project, cx)
2697            .context
2698            .update(cx, |store, cx| {
2699                store.refresh(buffer.clone(), cursor_position, cx);
2700            });
2701    }
2702
2703    #[cfg(feature = "cli-support")]
2704    pub fn set_context_for_buffer(
2705        &mut self,
2706        project: &Entity<Project>,
2707        related_files: Vec<RelatedFile>,
2708        cx: &mut Context<Self>,
2709    ) {
2710        self.get_or_init_project(project, cx)
2711            .context
2712            .update(cx, |store, cx| {
2713                store.set_related_files(related_files, cx);
2714            });
2715    }
2716
2717    #[cfg(feature = "cli-support")]
2718    pub fn set_recent_paths_for_project(
2719        &mut self,
2720        project: &Entity<Project>,
2721        paths: impl IntoIterator<Item = project::ProjectPath>,
2722        cx: &mut Context<Self>,
2723    ) {
2724        let project_state = self.get_or_init_project(project, cx);
2725        project_state.recent_paths = paths.into_iter().collect();
2726    }
2727
2728    fn is_file_open_source(
2729        &self,
2730        project: &Entity<Project>,
2731        file: &Arc<dyn File>,
2732        cx: &App,
2733    ) -> bool {
2734        if !file.is_local() || file.is_private() {
2735            return false;
2736        }
2737        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2738            return false;
2739        };
2740        project_state
2741            .license_detection_watchers
2742            .get(&file.worktree_id(cx))
2743            .as_ref()
2744            .is_some_and(|watcher| watcher.is_project_open_source())
2745    }
2746
2747    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2748        self.data_collection_choice.is_enabled(cx)
2749    }
2750
2751    fn load_data_collection_choice() -> DataCollectionChoice {
2752        let choice = KEY_VALUE_STORE
2753            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2754            .log_err()
2755            .flatten();
2756
2757        match choice.as_deref() {
2758            Some("true") => DataCollectionChoice::Enabled,
2759            Some("false") => DataCollectionChoice::Disabled,
2760            Some(_) => {
2761                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2762                DataCollectionChoice::NotAnswered
2763            }
2764            None => DataCollectionChoice::NotAnswered,
2765        }
2766    }
2767
2768    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2769        self.data_collection_choice = self.data_collection_choice.toggle();
2770        let new_choice = self.data_collection_choice;
2771        let is_enabled = new_choice.is_enabled(cx);
2772        db::write_and_log(cx, move || {
2773            KEY_VALUE_STORE.write_kvp(
2774                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2775                is_enabled.to_string(),
2776            )
2777        });
2778    }
2779
2780    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2781        self.shown_predictions.iter()
2782    }
2783
2784    pub fn shown_completions_len(&self) -> usize {
2785        self.shown_predictions.len()
2786    }
2787
2788    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2789        self.rated_predictions.contains(id)
2790    }
2791
2792    pub fn rate_prediction(
2793        &mut self,
2794        prediction: &EditPrediction,
2795        rating: EditPredictionRating,
2796        feedback: String,
2797        cx: &mut Context<Self>,
2798    ) {
2799        let organization = self.user_store.read(cx).current_organization();
2800
2801        self.rated_predictions.insert(prediction.id.clone());
2802
2803        cx.background_spawn({
2804            let client = self.client.clone();
2805            let prediction_id = prediction.id.to_string();
2806            let inputs = serde_json::to_value(&prediction.inputs);
2807            let output = prediction
2808                .edit_preview
2809                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2810            async move {
2811                client
2812                    .cloud_client()
2813                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2814                        organization_id: organization.map(|organization| organization.id.clone()),
2815                        request_id: prediction_id,
2816                        rating: match rating {
2817                            EditPredictionRating::Positive => "positive".to_string(),
2818                            EditPredictionRating::Negative => "negative".to_string(),
2819                        },
2820                        inputs: inputs?,
2821                        output,
2822                        feedback,
2823                    })
2824                    .await?;
2825
2826                anyhow::Ok(())
2827            }
2828        })
2829        .detach_and_log_err(cx);
2830
2831        cx.notify();
2832    }
2833}
2834
2835fn collaborator_edit_overlaps_locality_region(
2836    project_state: &ProjectState,
2837    project: &Entity<Project>,
2838    buffer: &Entity<Buffer>,
2839    snapshot: &BufferSnapshot,
2840    edit_range: &Range<Anchor>,
2841    cx: &App,
2842) -> bool {
2843    let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2844        return false;
2845    };
2846
2847    if active_buffer.entity_id() != buffer.entity_id() {
2848        return false;
2849    }
2850
2851    let locality_point_range = expand_context_syntactically_then_linewise(
2852        snapshot,
2853        (position..position).to_point(snapshot),
2854        COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2855    );
2856    let locality_anchor_range = snapshot.anchor_range_around(locality_point_range);
2857
2858    edit_range.overlaps(&locality_anchor_range, snapshot)
2859}
2860
/// Collapses the trailing run of mergeable events in `events` into a single event.
///
/// Nothing happens in either of these cases:
/// - the newest stored event belongs to a different buffer than `latest_snapshot`
///   (compared by `remote_id`);
/// - `latest_snapshot` has not observed all of that event's `new_snapshot_version`.
///
/// Otherwise, `events` is walked from newest to oldest. Each event is counted as long
/// as it `can_merge` with the next-newer event; the newest event always counts. If two
/// or more events qualify, they are replaced by one `BufferChange` event:
/// - Its diff is recomputed between the oldest event's `old_snapshot` and
///   `end_snapshot`, restricted to the union of the events' `total_edit_range`s
///   (ordered using `latest_snapshot`).
/// - Its `predicted` flag is set only if every merged event was predicted.
/// - Its `new_snapshot_version` and `total_edit_range` are taken from `end_snapshot`.
///
/// If `compute_diff_between_snapshots_in_range` returns `None`, `events` is left
/// unchanged.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    if let Some(last_event) = events.back() {
        // Only merge into events for the same buffer...
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        // ...and only when the latest snapshot has seen everything the last event saw.
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Count the trailing run of events where each older event can merge with the
    // next-newer one.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A run of zero or one event has nothing to merge.
    if mergeable_count <= 1 {
        return;
    }

    // `peek` does not advance, so `events_to_merge` still yields the oldest event when
    // it is consumed by `all` below.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    // Union of the edit ranges of the remaining (newer) events in the run.
    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // Predicted only if every event in the run was predicted.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Replace the merged run with the single combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2943
2944fn merge_anchor_ranges(
2945    left: &Range<Anchor>,
2946    right: &Range<Anchor>,
2947    snapshot: &TextBufferSnapshot,
2948) -> Range<Anchor> {
2949    let start = if left.start.cmp(&right.start, snapshot).is_le() {
2950        left.start
2951    } else {
2952        right.start
2953    };
2954    let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2955        left.end
2956    } else {
2957        right.end
2958    };
2959    start..end
2960}
2961
2962#[derive(Error, Debug)]
2963#[error(
2964    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2965)]
2966pub struct ZedUpdateRequiredError {
2967    minimum_version: Version,
2968}
2969
/// The user's choice about sharing edit-prediction data.
///
/// The choice is persisted in the key-value store as `"true"` or `"false"`.
/// `NotAnswered` means no usable value has been stored.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// The user has not made a choice yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2976
2977impl DataCollectionChoice {
2978    pub fn is_enabled(self, cx: &App) -> bool {
2979        if cx.is_staff() {
2980            return true;
2981        }
2982        match self {
2983            Self::Enabled => true,
2984            Self::NotAnswered | Self::Disabled => false,
2985        }
2986    }
2987
2988    #[must_use]
2989    pub fn toggle(&self) -> DataCollectionChoice {
2990        match self {
2991            Self::Enabled => Self::Disabled,
2992            Self::Disabled => Self::Enabled,
2993            Self::NotAnswered => Self::Enabled,
2994        }
2995    }
2996}
2997
2998impl From<bool> for DataCollectionChoice {
2999    fn from(value: bool) -> Self {
3000        match value {
3001            true => DataCollectionChoice::Enabled,
3002            false => DataCollectionChoice::Disabled,
3003        }
3004    }
3005}
3006
/// Marker type for the dismissable Zed Predict upsell.
struct ZedPredictUpsell;
3008
3009impl Dismissable for ZedPredictUpsell {
3010    const KEY: &'static str = "dismissed-edit-predict-upsell";
3011
3012    fn dismissed() -> bool {
3013        // To make this backwards compatible with older versions of Zed, we
3014        // check if the user has seen the previous Edit Prediction Onboarding
3015        // before, by checking the data collection choice which was written to
3016        // the database once the user clicked on "Accept and Enable"
3017        if KEY_VALUE_STORE
3018            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
3019            .log_err()
3020            .is_some_and(|s| s.is_some())
3021        {
3022            return true;
3023        }
3024
3025        KEY_VALUE_STORE
3026            .read_kvp(Self::KEY)
3027            .log_err()
3028            .is_some_and(|s| s.is_some())
3029    }
3030}
3031
3032pub fn should_show_upsell_modal() -> bool {
3033    !ZedPredictUpsell::dismissed()
3034}
3035
3036pub fn init(cx: &mut App) {
3037    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
3038        workspace.register_action(
3039            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
3040                ZedPredictModal::toggle(
3041                    workspace,
3042                    workspace.user_store().clone(),
3043                    workspace.client().clone(),
3044                    window,
3045                    cx,
3046                )
3047            },
3048        );
3049
3050        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
3051            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
3052                settings
3053                    .project
3054                    .all_languages
3055                    .edit_predictions
3056                    .get_or_insert_default()
3057                    .provider = Some(EditPredictionProvider::None)
3058            });
3059        });
3060        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
3061            EditPredictionStore::try_global(cx).and_then(|store| {
3062                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
3063            })
3064        }
3065
3066        workspace.register_action(|workspace, _: &SignIn, window, cx| {
3067            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3068                copilot_ui::initiate_sign_in(copilot, window, cx);
3069            }
3070        });
3071        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3072            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3073                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3074            }
3075        });
3076        workspace.register_action(|workspace, _: &SignOut, window, cx| {
3077            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3078                copilot_ui::initiate_sign_out(copilot, window, cx);
3079            }
3080        });
3081    })
3082    .detach();
3083}