edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
   5use cloud_llm_client::predict_edits_v3::{
   6    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   7};
   8use cloud_llm_client::{
   9    EditPredictionRejectReason, EditPredictionRejection,
  10    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  11    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  12};
  13use collections::{HashMap, HashSet};
  14use copilot::{Copilot, Reinstall, SignIn, SignOut};
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::{
  19    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  20    channel::mpsc::{self, UnboundedReceiver},
  21    select_biased,
  22};
  23use gpui::BackgroundExecutor;
  24use gpui::http_client::Url;
  25use gpui::{
  26    App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
  27    http_client::{self, AsyncBody, Method},
  28    prelude::*,
  29};
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{
  39    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  40};
  41use std::collections::{VecDeque, hash_map};
  42use std::env;
  43use text::{AnchorRangeExt, Edit};
  44use workspace::{AppState, Workspace};
  45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  46
  47use std::mem;
  48use std::ops::Range;
  49use std::path::Path;
  50use std::rc::Rc;
  51use std::str::FromStr as _;
  52use std::sync::Arc;
  53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  54use thiserror::Error;
  55use util::{RangeExt as _, ResultExt as _};
  56
  57pub mod cursor_excerpt;
  58pub mod example_spec;
  59pub mod fim;
  60mod license_detection;
  61pub mod mercury;
  62pub mod ollama;
  63mod onboarding_modal;
  64pub mod open_ai_response;
  65mod prediction;
  66pub mod sweep_ai;
  67
  68pub mod udiff;
  69
  70mod capture_example;
  71pub mod open_ai_compatible;
  72mod zed_edit_prediction_delegate;
  73pub mod zeta;
  74
  75#[cfg(test)]
  76mod edit_prediction_tests;
  77
  78use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
  79use crate::example_spec::ExampleSpec;
  80use crate::license_detection::LicenseDetectionWatcher;
  81use crate::mercury::Mercury;
  82use crate::onboarding_modal::ZedPredictModal;
  83pub use crate::prediction::EditPrediction;
  84pub use crate::prediction::EditPredictionId;
  85use crate::prediction::EditPredictionResult;
  86pub use crate::sweep_ai::SweepAi;
  87pub use capture_example::capture_example;
  88pub use language_model::ApiKeyState;
  89pub use telemetry_events::EditPredictionRating;
  90pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  91
// Declares the actions of this crate; `edit_prediction` is the first argument
// handed to gpui's `actions!` macro, followed by the list of action names.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
 101
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Row-distance threshold used by `StoredEvent::can_merge` when deciding whether
/// two edit ranges count as close to each other.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): not referenced in this part of the file; by its name a token
// budget for context around collaborator edits — confirm at the use site.
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
// NOTE(review): not referenced in this part of the file; presumably the pause
// length that separates groups of edits — confirm at the use site.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// NOTE(review): looks like a persistence key for the data-collection choice — confirm.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the delay used to batch rejection reports — confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
// NOTE(review): presumably the telemetry event name for settled predictions — confirm.
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
/// Five minutes. NOTE(review): presumably how long a prediction may wait to settle — confirm.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
/// Ten seconds. NOTE(review): presumably the edit-free interval required before a
/// prediction is treated as settled — confirm.
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
 112
/// Marker type for the feature flag named `edit_prediction_jumps`.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    /// Name under which the flag is registered.
    const NAME: &'static str = "edit_prediction_jumps";
}
 118
/// Newtype around the shared `EditPredictionStore` entity so it can be stored as
/// a gpui global (see `EditPredictionStore::global` / `try_global`).
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 123
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model id; `zeta2_raw_config_from_env` reads it from `ZED_ZETA_MODEL`.
    pub model_id: Option<String>,
    /// Optional environment name; `zeta2_raw_config_from_env` reads it from
    /// `ZED_ZETA_ENVIRONMENT`.
    pub environment: Option<String>,
    /// Prompt format; `zeta2_raw_config_from_env` parses it from `ZED_ZETA_FORMAT`.
    pub format: ZetaFormat,
}
 133
/// Central state for edit predictions; a single instance is shared through the
/// `EditPredictionStoreGlobal` gpui global.
pub struct EditPredictionStore {
    /// Client passed to `new`; cloned into the background tasks that talk to the server.
    client: Arc<Client>,
    /// User store passed to `new`; watched for the current user and organization.
    user_store: Entity<UserStore>,
    /// Token handle obtained from `LlmApiToken::global` in `new`.
    llm_token: LlmApiToken,
    /// Task spawned in `new` that waits for a current user and then calls
    /// `refresh_available_experiments`.
    _fetch_experiments_task: Task<()>,
    /// Per-project state. NOTE(review): keys are presumably project entity ids — confirm.
    projects: HashMap<EntityId, ProjectState>,
    /// Starts as `false`. NOTE(review): presumably set when the server demands a newer
    /// client version — confirm.
    update_required: bool,
    /// Active prediction model; `new` initializes it to `EditPredictionModel::Zeta`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration, initialized from the environment in `new`.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    /// Explicitly chosen experiment; takes precedence in `active_experiment`.
    preferred_experiment: Option<String>,
    /// Known experiment names; starts empty.
    available_experiments: Vec<String>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    /// Loaded via `load_data_collection_choice` in `new`.
    data_collection_choice: DataCollectionChoice,
    /// Sending half of the channel drained by the `handle_rejected_predictions`
    /// background task.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Sending half of the channel drained by `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    /// Predictions consulted by `active_experiment` for a model version.
    shown_predictions: VecDeque<EditPrediction>,
    /// NOTE(review): presumably ids of predictions the user has already rated — confirm.
    rated_predictions: HashSet<EditPredictionId>,
    /// Test-only hook; initialized to `None`.
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
 155
/// Message type carried by `EditPredictionStore::reject_predictions_tx`.
pub(crate) struct EditPredictionRejectionPayload {
    /// The rejection being reported.
    rejection: EditPredictionRejection,
    /// Organization associated with the rejection, if any.
    organization_id: Option<OrganizationId>,
}
 160
/// Selects which backend produces edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    /// The default chosen by `EditPredictionStore::new`.
    Zeta,
    /// NOTE(review): presumably fill-in-the-middle completion using the given prompt
    /// format — confirm.
    Fim { format: EditPredictionPromptFormat },
    Sweep,
    Mercury,
}
 168
/// Bundle of inputs describing a single prediction request.
///
/// NOTE(review): field meanings below are inferred from names and types only; the
/// code that fills this struct is outside this part of the file.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Anchor within the buffer; presumably the cursor position — confirm.
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional channel for `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    pub user_actions: Vec<UserActionRecord>,
}
 185
/// Debug notifications; this is the message type of the `debug_tx` channels held by
/// `ProjectState` and `EditPredictionModelInput`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
 193
/// Payload of `DebugEvent::ContextRetrievalStarted`.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}
 200
/// Payload of `DebugEvent::ContextRetrievalFinished`.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Pairs of a static label and a value.
    pub metadata: Vec<(&'static str, SharedString)>,
}
 207
/// Payload of `DebugEvent::EditPredictionStarted`.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    /// Weak handle, so the event does not keep the buffer alive.
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Prompt text, when one is available.
    pub prompt: Option<String>,
}
 214
/// Payload of `DebugEvent::EditPredictionFinished`.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    /// Weak handle, so the event does not keep the buffer alive.
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Model output text, when one is available.
    pub model_output: Option<String>,
}
 221
/// Capacity of `ProjectState::user_actions`; `record_user_action` drops the oldest
/// entry once this many are stored.
const USER_ACTION_HISTORY_SIZE: usize = 16;
 223
/// One entry of the per-project user action history.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    /// NOTE(review): by its name, milliseconds since the Unix epoch — confirm at the
    /// construction site.
    pub timestamp_epoch_ms: u64,
}
 232
/// Kind of action stored in a `UserActionRecord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 241
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-facing event.
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the event's edits.
    pub old_snapshot: TextBufferSnapshot,
    /// Version of the buffer snapshot after the event's edits.
    pub new_snapshot_version: clock::Global,
    /// Anchor range covering the edits of this event.
    pub total_edit_range: Range<Anchor>,
}
 250
 251impl StoredEvent {
 252    fn can_merge(
 253        &self,
 254        next_old_event: &StoredEvent,
 255        latest_snapshot: &TextBufferSnapshot,
 256        latest_edit_range: &Range<Anchor>,
 257    ) -> bool {
 258        // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
 259        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 260            return false;
 261        }
 262        if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
 263            return false;
 264        }
 265        if self.new_snapshot_version != next_old_event.old_snapshot.version {
 266            return false;
 267        }
 268        if !latest_snapshot
 269            .version
 270            .observed_all(&next_old_event.new_snapshot_version)
 271        {
 272            return false;
 273        }
 274
 275        let a_is_predicted = matches!(
 276            self.event.as_ref(),
 277            zeta_prompt::Event::BufferChange {
 278                predicted: true,
 279                ..
 280            }
 281        );
 282        let b_is_predicted = matches!(
 283            next_old_event.event.as_ref(),
 284            zeta_prompt::Event::BufferChange {
 285                predicted: true,
 286                ..
 287            }
 288        );
 289
 290        // If events come from the same source (both predicted or both manual) then
 291        // we would have coalesced them already.
 292        if a_is_predicted == b_is_predicted {
 293            return false;
 294        }
 295
 296        let left_range = self.total_edit_range.to_point(latest_snapshot);
 297        let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
 298        let latest_range = latest_edit_range.to_point(latest_snapshot);
 299
 300        // Events near to the latest edit are not merged if their sources differ.
 301        if lines_between_ranges(&left_range, &latest_range)
 302            .min(lines_between_ranges(&right_range, &latest_range))
 303            <= CHANGE_GROUPING_LINE_SPAN
 304        {
 305            return false;
 306        }
 307
 308        // Events that are distant from each other are not merged.
 309        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 310            return false;
 311        }
 312
 313        true
 314    }
 315}
 316
 317fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 318    if left.start > right.end {
 319        return left.start.row - right.end.row;
 320    }
 321    if right.start > left.end {
 322        return right.start.row - left.end.row;
 323    }
 324    0
 325}
 326
/// Edit prediction state tracked for a single project.
struct ProjectState {
    /// Finalized events, oldest first; `events()` appends the pending `last_event`.
    events: VecDeque<StoredEvent>,
    /// Most recent, not yet finalized event.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Registered buffers, looked up by buffer entity id in `active_buffer`.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// At most two predictions may be pending at once (fixed `ArrayVec` capacity).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Optional channel for `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids recorded by `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// Per-worktree watchers consulted by `LastEvent::finalize` to decide whether an
    /// event comes from an open-source project.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history; see `record_user_action`.
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
 345
 346impl ProjectState {
 347    fn record_user_action(&mut self, action: UserActionRecord) {
 348        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 349            self.user_actions.pop_front();
 350        }
 351        self.user_actions.push_back(action);
 352    }
 353
 354    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 355        self.events
 356            .iter()
 357            .cloned()
 358            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 359                let (one, two) = event.split_by_pause();
 360                let one = one.finalize(&self.license_detection_watchers, cx);
 361                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 362                one.into_iter().chain(two)
 363            }))
 364            .collect()
 365    }
 366
 367    fn cancel_pending_prediction(
 368        &mut self,
 369        pending_prediction: PendingPrediction,
 370        cx: &mut Context<EditPredictionStore>,
 371    ) {
 372        self.cancelled_predictions.insert(pending_prediction.id);
 373
 374        if pending_prediction.drop_on_cancel {
 375            drop(pending_prediction.task);
 376        } else {
 377            cx.spawn(async move |this, cx| {
 378                let Some(prediction_id) = pending_prediction.task.await else {
 379                    return;
 380                };
 381
 382                this.update(cx, |this, cx| {
 383                    this.reject_prediction(
 384                        prediction_id,
 385                        EditPredictionRejectReason::Canceled,
 386                        false,
 387                        None,
 388                        cx,
 389                    );
 390                })
 391                .ok();
 392            })
 393            .detach()
 394        }
 395    }
 396
 397    fn active_buffer(
 398        &self,
 399        project: &Entity<Project>,
 400        cx: &App,
 401    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 402        let project = project.read(cx);
 403        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 404        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 405        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 406        Some((active_buffer, registered_buffer.last_position))
 407    }
 408}
 409
/// The prediction currently held for a project, plus bookkeeping about its origin
/// and display.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
 417
 418impl CurrentEditPrediction {
 419    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 420        let Some(new_edits) = self
 421            .prediction
 422            .interpolate(&self.prediction.buffer.read(cx))
 423        else {
 424            return false;
 425        };
 426
 427        if self.prediction.buffer != old_prediction.prediction.buffer {
 428            return true;
 429        }
 430
 431        let Some(old_edits) = old_prediction
 432            .prediction
 433            .interpolate(&old_prediction.prediction.buffer.read(cx))
 434        else {
 435            return true;
 436        };
 437
 438        let requested_by_buffer_id = self.requested_by.buffer_id();
 439
 440        // This reduces the occurrence of UI thrash from replacing edits
 441        //
 442        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 443        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 444            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 445            && old_edits.len() == 1
 446            && new_edits.len() == 1
 447        {
 448            let (old_range, old_text) = &old_edits[0];
 449            let (new_range, new_text) = &new_edits[0];
 450            new_range == old_range && new_text.starts_with(old_text.as_ref())
 451        } else {
 452            true
 453        }
 454    }
 455}
 456
 457#[derive(Debug, Clone)]
 458enum PredictionRequestedBy {
 459    DiagnosticsUpdate,
 460    Buffer(EntityId),
 461}
 462
 463impl PredictionRequestedBy {
 464    pub fn buffer_id(&self) -> Option<EntityId> {
 465        match self {
 466            PredictionRequestedBy::DiagnosticsUpdate => None,
 467            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 468        }
 469    }
 470}
 471
// NOTE(review): not referenced in this part of the file; presumably the number of
// rows around a position searched for diagnostics — confirm at the use site.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;
 473
/// Scope of a diagnostics search.
///
/// NOTE(review): the exact meaning of each variant is defined where it is matched,
/// outside this part of the file — confirm before relying on the names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
 479
/// A prediction request that is still in flight.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    /// Resolves to the id of the resulting prediction, or `None` if there is none.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 488
 489/// A prediction from the perspective of a buffer.
 490#[derive(Debug)]
 491enum BufferEditPrediction<'a> {
 492    Local { prediction: &'a EditPrediction },
 493    Jump { prediction: &'a EditPrediction },
 494}
 495
 496#[cfg(test)]
 497impl std::ops::Deref for BufferEditPrediction<'_> {
 498    type Target = EditPrediction;
 499
 500    fn deref(&self) -> &Self::Target {
 501        match self {
 502            BufferEditPrediction::Local { prediction } => prediction,
 503            BufferEditPrediction::Jump { prediction } => prediction,
 504        }
 505    }
 506}
 507
/// A prediction queued on a `RegisteredBuffer` (see its `pending_predictions` field).
///
/// NOTE(review): the code that consumes these entries is outside this part of the
/// file; the field notes below are inferred from names and types — confirm.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    editable_anchor_range: Range<Anchor>,
    example: Option<ExampleSpec>,
    /// Presumably when the entry was queued.
    enqueued_at: Instant,
    /// Presumably the time of the most recent edit relevant to this entry.
    last_edit_at: Instant,
}
 516
/// State kept for each buffer registered with a `ProjectState`.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last recorded position, returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
 524
/// The most recent buffer change, still open to further edits; `finalize` turns it
/// into a `StoredEvent`.
#[derive(Clone)]
struct LastEvent {
    /// Snapshot from before the change.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot after the change.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    latest_edit_range: Range<Anchor>,
    /// Range covering all edits of this change; the diff in `finalize` is computed
    /// around it.
    total_edit_range: Range<Anchor>,
    /// Value used by `split_by_pause` as the range of the part before the pause;
    /// `total_edit_range` is the fallback when this is `None`.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    /// Copied into the `predicted` flag of the finalized event.
    predicted: bool,
    /// Boundary snapshot at which `split_by_pause` divides the change, if any.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 538
 539impl LastEvent {
 540    pub fn finalize(
 541        &self,
 542        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 543        cx: &App,
 544    ) -> Option<StoredEvent> {
 545        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 546        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 547
 548        let in_open_source_repo =
 549            [self.new_file.as_ref(), self.old_file.as_ref()]
 550                .iter()
 551                .all(|file| {
 552                    file.is_some_and(|file| {
 553                        license_detection_watchers
 554                            .get(&file.worktree_id(cx))
 555                            .is_some_and(|watcher| watcher.is_project_open_source())
 556                    })
 557                });
 558
 559        let (diff, edit_range) = compute_diff_between_snapshots_in_range(
 560            &self.old_snapshot,
 561            &self.new_snapshot,
 562            &self.total_edit_range,
 563        )?;
 564
 565        if path == old_path && diff.is_empty() {
 566            None
 567        } else {
 568            Some(StoredEvent {
 569                event: Arc::new(zeta_prompt::Event::BufferChange {
 570                    old_path,
 571                    path,
 572                    diff,
 573                    in_open_source_repo,
 574                    predicted: self.predicted,
 575                }),
 576                old_snapshot: self.old_snapshot.clone(),
 577                new_snapshot_version: self.new_snapshot.version.clone(),
 578                total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
 579                    ..self.new_snapshot.anchor_before(edit_range.end),
 580            })
 581        }
 582    }
 583
 584    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 585        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 586            return (self.clone(), None);
 587        };
 588
 589        let total_edit_range_before_pause = self
 590            .total_edit_range_at_last_pause_boundary
 591            .clone()
 592            .unwrap_or_else(|| self.total_edit_range.clone());
 593
 594        let Some(total_edit_range_after_pause) =
 595            compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
 596        else {
 597            return (self.clone(), None);
 598        };
 599
 600        let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
 601        let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
 602
 603        let before = LastEvent {
 604            old_snapshot: self.old_snapshot.clone(),
 605            new_snapshot: boundary_snapshot.clone(),
 606            old_file: self.old_file.clone(),
 607            new_file: self.new_file.clone(),
 608            latest_edit_range: latest_edit_range_before_pause,
 609            total_edit_range: total_edit_range_before_pause,
 610            total_edit_range_at_last_pause_boundary: None,
 611            predicted: self.predicted,
 612            snapshot_after_last_editing_pause: None,
 613            last_edit_time: self.last_edit_time,
 614        };
 615
 616        let after = LastEvent {
 617            old_snapshot: boundary_snapshot.clone(),
 618            new_snapshot: self.new_snapshot.clone(),
 619            old_file: self.old_file.clone(),
 620            new_file: self.new_file.clone(),
 621            latest_edit_range: latest_edit_range_after_pause,
 622            total_edit_range: total_edit_range_after_pause,
 623            total_edit_range_at_last_pause_boundary: None,
 624            predicted: self.predicted,
 625            snapshot_after_last_editing_pause: None,
 626            last_edit_time: self.last_edit_time,
 627        };
 628
 629        (before, Some(after))
 630    }
 631}
 632
 633fn compute_total_edit_range_between_snapshots(
 634    old_snapshot: &TextBufferSnapshot,
 635    new_snapshot: &TextBufferSnapshot,
 636) -> Option<Range<Anchor>> {
 637    let edits: Vec<Edit<usize>> = new_snapshot
 638        .edits_since::<usize>(&old_snapshot.version)
 639        .collect();
 640
 641    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 642    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 643    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 644
 645    Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
 646}
 647
 648fn compute_old_range_for_new_range(
 649    old_snapshot: &TextBufferSnapshot,
 650    new_snapshot: &TextBufferSnapshot,
 651    total_edit_range: &Range<Anchor>,
 652) -> Option<Range<Point>> {
 653    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
 654    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);
 655
 656    let edits: Vec<Edit<usize>> = new_snapshot
 657        .edits_since::<usize>(&old_snapshot.version)
 658        .collect();
 659    let mut old_start_offset = None;
 660    let mut old_end_offset = None;
 661    let mut delta: isize = 0;
 662
 663    for edit in &edits {
 664        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
 665            old_start_offset = Some(if new_start_offset < edit.new.start {
 666                new_start_offset.checked_add_signed(-delta)?
 667            } else {
 668                edit.old.start
 669            });
 670        }
 671
 672        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
 673            old_end_offset = Some(if new_end_offset < edit.new.start {
 674                new_end_offset.checked_add_signed(-delta)?
 675            } else {
 676                edit.old.end
 677            });
 678        }
 679
 680        delta += edit.new.len() as isize - edit.old.len() as isize;
 681    }
 682
 683    let old_start_offset =
 684        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
 685    let old_end_offset =
 686        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));
 687
 688    Some(
 689        old_snapshot.offset_to_point(old_start_offset)
 690            ..old_snapshot.offset_to_point(old_end_offset),
 691    )
 692}
 693
 694fn compute_diff_between_snapshots_in_range(
 695    old_snapshot: &TextBufferSnapshot,
 696    new_snapshot: &TextBufferSnapshot,
 697    total_edit_range: &Range<Anchor>,
 698) -> Option<(String, Range<Point>)> {
 699    let new_start_point = total_edit_range.start.to_point(new_snapshot);
 700    let new_end_point = total_edit_range.end.to_point(new_snapshot);
 701    let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
 702    let old_start_point = old_range.start;
 703    let old_end_point = old_range.end;
 704
 705    const CONTEXT_LINES: u32 = 3;
 706
 707    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 708    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 709    let old_context_end_row =
 710        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 711    let new_context_end_row =
 712        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 713
 714    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 715    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 716    let old_end_line_offset = old_snapshot
 717        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 718    let new_end_line_offset = new_snapshot
 719        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 720    let old_edit_range = old_start_line_offset..old_end_line_offset;
 721    let new_edit_range = new_start_line_offset..new_end_line_offset;
 722
 723    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 724    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 725
 726    let diff = language::unified_diff_with_offsets(
 727        &old_region_text,
 728        &new_region_text,
 729        old_context_start_row,
 730        new_context_start_row,
 731    );
 732
 733    Some((diff, new_start_point..new_end_point))
 734}
 735
 736fn buffer_path_with_id_fallback(
 737    file: Option<&Arc<dyn File>>,
 738    snapshot: &TextBufferSnapshot,
 739    cx: &App,
 740) -> Arc<Path> {
 741    if let Some(file) = file {
 742        file.full_path(cx).into()
 743    } else {
 744        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 745    }
 746}
 747
 748impl EditPredictionStore {
 749    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 750        cx.try_global::<EditPredictionStoreGlobal>()
 751            .map(|global| global.0.clone())
 752    }
 753
 754    pub fn global(
 755        client: &Arc<Client>,
 756        user_store: &Entity<UserStore>,
 757        cx: &mut App,
 758    ) -> Entity<Self> {
 759        cx.try_global::<EditPredictionStoreGlobal>()
 760            .map(|global| global.0.clone())
 761            .unwrap_or_else(|| {
 762                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 763                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 764                ep_store
 765            })
 766    }
 767
 768    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 769        let data_collection_choice = Self::load_data_collection_choice();
 770
 771        let llm_token = LlmApiToken::global(cx);
 772
 773        let (reject_tx, reject_rx) = mpsc::unbounded();
 774        cx.background_spawn({
 775            let client = client.clone();
 776            let llm_token = llm_token.clone();
 777            let app_version = AppVersion::global(cx);
 778            let background_executor = cx.background_executor().clone();
 779            async move {
 780                Self::handle_rejected_predictions(
 781                    reject_rx,
 782                    client,
 783                    llm_token,
 784                    app_version,
 785                    background_executor,
 786                )
 787                .await
 788            }
 789        })
 790        .detach();
 791
 792        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
 793        cx.spawn(async move |this, cx| {
 794            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
 795        })
 796        .detach();
 797
 798        let mut current_user = user_store.read(cx).watch_current_user();
 799        let fetch_experiments_task = cx.spawn(async move |this, cx| {
 800            while current_user.borrow().is_none() {
 801                current_user.next().await;
 802            }
 803            this.update(cx, |this, cx| {
 804                this.refresh_available_experiments(cx);
 805            })
 806            .log_err();
 807        });
 808
 809        let this = Self {
 810            projects: HashMap::default(),
 811            client,
 812            user_store,
 813            llm_token,
 814            _fetch_experiments_task: fetch_experiments_task,
 815            update_required: false,
 816            edit_prediction_model: EditPredictionModel::Zeta,
 817            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 818            preferred_experiment: None,
 819            available_experiments: Vec::new(),
 820            sweep_ai: SweepAi::new(cx),
 821            mercury: Mercury::new(cx),
 822
 823            data_collection_choice,
 824            reject_predictions_tx: reject_tx,
 825            settled_predictions_tx,
 826            rated_predictions: Default::default(),
 827            shown_predictions: Default::default(),
 828            #[cfg(test)]
 829            settled_event_callback: None,
 830        };
 831
 832        this
 833    }
 834
 835    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 836        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 837        let format = ZetaFormat::parse(&version_str).ok()?;
 838        let model_id = env::var("ZED_ZETA_MODEL").ok();
 839        let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
 840        Some(Zeta2RawConfig {
 841            model_id,
 842            environment,
 843            format,
 844        })
 845    }
 846
 847    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 848        self.edit_prediction_model = model;
 849    }
 850
 851    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
 852        self.zeta2_raw_config = Some(config);
 853    }
 854
 855    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
 856        self.zeta2_raw_config.as_ref()
 857    }
 858
 859    pub fn preferred_experiment(&self) -> Option<&str> {
 860        self.preferred_experiment.as_deref()
 861    }
 862
 863    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
 864        self.preferred_experiment = experiment;
 865    }
 866
 867    pub fn available_experiments(&self) -> &[String] {
 868        &self.available_experiments
 869    }
 870
 871    pub fn active_experiment(&self) -> Option<&str> {
 872        self.preferred_experiment.as_deref().or_else(|| {
 873            self.shown_predictions
 874                .iter()
 875                .find_map(|p| p.model_version.as_ref())
 876                .and_then(|model_version| model_version.strip_prefix("zeta2:"))
 877        })
 878    }
 879
 880    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
 881        let client = self.client.clone();
 882        let llm_token = self.llm_token.clone();
 883        let app_version = AppVersion::global(cx);
 884        let organization_id = self
 885            .user_store
 886            .read(cx)
 887            .current_organization()
 888            .map(|organization| organization.id.clone());
 889
 890        cx.spawn(async move |this, cx| {
 891            let experiments = cx
 892                .background_spawn(async move {
 893                    let http_client = client.http_client();
 894                    let token = llm_token.acquire(&client, organization_id).await?;
 895                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
 896                    let request = http_client::Request::builder()
 897                        .method(Method::GET)
 898                        .uri(url.as_ref())
 899                        .header("Authorization", format!("Bearer {}", token))
 900                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 901                        .body(Default::default())?;
 902                    let mut response = http_client.send(request).await?;
 903                    if response.status().is_success() {
 904                        let mut body = Vec::new();
 905                        response.body_mut().read_to_end(&mut body).await?;
 906                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
 907                        Ok(experiments)
 908                    } else {
 909                        let mut body = String::new();
 910                        response.body_mut().read_to_string(&mut body).await?;
 911                        anyhow::bail!(
 912                            "Failed to fetch experiments: {:?}\nBody: {}",
 913                            response.status(),
 914                            body
 915                        );
 916                    }
 917                })
 918                .await?;
 919            this.update(cx, |this, cx| {
 920                this.available_experiments = experiments;
 921                cx.notify();
 922            })?;
 923            anyhow::Ok(())
 924        })
 925        .detach_and_log_err(cx);
 926    }
 927
 928    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 929        use ui::IconName;
 930        match self.edit_prediction_model {
 931            EditPredictionModel::Sweep => {
 932                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 933                    .with_disabled(IconName::SweepAiDisabled)
 934                    .with_up(IconName::SweepAiUp)
 935                    .with_down(IconName::SweepAiDown)
 936                    .with_error(IconName::SweepAiError)
 937            }
 938            EditPredictionModel::Mercury => {
 939                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 940            }
 941            EditPredictionModel::Zeta => {
 942                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 943                    .with_disabled(IconName::ZedPredictDisabled)
 944                    .with_up(IconName::ZedPredictUp)
 945                    .with_down(IconName::ZedPredictDown)
 946                    .with_error(IconName::ZedPredictError)
 947            }
 948            EditPredictionModel::Fim { .. } => {
 949                let settings = &all_language_settings(None, cx).edit_predictions;
 950                match settings.provider {
 951                    EditPredictionProvider::Ollama => {
 952                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 953                    }
 954                    _ => {
 955                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 956                    }
 957                }
 958            }
 959        }
 960    }
 961
 962    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 963        self.sweep_ai.api_token.read(cx).has_key()
 964    }
 965
 966    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 967        self.mercury.api_token.read(cx).has_key()
 968    }
 969
 970    pub fn mercury_has_payment_required_error(&self) -> bool {
 971        self.mercury.has_payment_required_error()
 972    }
 973
 974    pub fn clear_history(&mut self) {
 975        for project_state in self.projects.values_mut() {
 976            project_state.events.clear();
 977            project_state.last_event.take();
 978        }
 979    }
 980
 981    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 982        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 983            project_state.events.clear();
 984            project_state.last_event.take();
 985        }
 986    }
 987
 988    pub fn edit_history_for_project(
 989        &self,
 990        project: &Entity<Project>,
 991        cx: &App,
 992    ) -> Vec<StoredEvent> {
 993        self.projects
 994            .get(&project.entity_id())
 995            .map(|project_state| project_state.events(cx))
 996            .unwrap_or_default()
 997    }
 998
 999    pub fn context_for_project<'a>(
1000        &'a self,
1001        project: &Entity<Project>,
1002        cx: &'a mut App,
1003    ) -> Vec<RelatedFile> {
1004        self.projects
1005            .get(&project.entity_id())
1006            .map(|project_state| {
1007                project_state.context.update(cx, |context, cx| {
1008                    context
1009                        .related_files_with_buffers(cx)
1010                        .map(|(mut related_file, buffer)| {
1011                            related_file.in_open_source_repo = buffer
1012                                .read(cx)
1013                                .file()
1014                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
1015                            related_file
1016                        })
1017                        .collect()
1018                })
1019            })
1020            .unwrap_or_default()
1021    }
1022
1023    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1024        self.projects
1025            .get(&project.entity_id())
1026            .and_then(|project| project.copilot.clone())
1027    }
1028
1029    pub fn start_copilot_for_project(
1030        &mut self,
1031        project: &Entity<Project>,
1032        cx: &mut Context<Self>,
1033    ) -> Option<Entity<Copilot>> {
1034        if DisableAiSettings::get(None, cx).disable_ai {
1035            return None;
1036        }
1037        let state = self.get_or_init_project(project, cx);
1038
1039        if state.copilot.is_some() {
1040            return state.copilot.clone();
1041        }
1042        let _project = project.clone();
1043        let project = project.read(cx);
1044
1045        let node = project.node_runtime().cloned();
1046        if let Some(node) = node {
1047            let next_id = project.languages().next_language_server_id();
1048            let fs = project.fs().clone();
1049
1050            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1051            state.copilot = Some(copilot.clone());
1052            Some(copilot)
1053        } else {
1054            None
1055        }
1056    }
1057
1058    pub fn context_for_project_with_buffers<'a>(
1059        &'a self,
1060        project: &Entity<Project>,
1061        cx: &'a mut App,
1062    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1063        self.projects
1064            .get(&project.entity_id())
1065            .map(|project| {
1066                project.context.update(cx, |context, cx| {
1067                    context.related_files_with_buffers(cx).collect()
1068                })
1069            })
1070            .unwrap_or_default()
1071    }
1072
1073    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1074        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1075            self.user_store.read(cx).edit_prediction_usage()
1076        } else {
1077            None
1078        }
1079    }
1080
1081    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1082        self.get_or_init_project(project, cx);
1083    }
1084
1085    pub fn register_buffer(
1086        &mut self,
1087        buffer: &Entity<Buffer>,
1088        project: &Entity<Project>,
1089        cx: &mut Context<Self>,
1090    ) {
1091        let project_state = self.get_or_init_project(project, cx);
1092        Self::register_buffer_impl(project_state, buffer, project, cx);
1093    }
1094
1095    fn get_or_init_project(
1096        &mut self,
1097        project: &Entity<Project>,
1098        cx: &mut Context<Self>,
1099    ) -> &mut ProjectState {
1100        let entity_id = project.entity_id();
1101        self.projects
1102            .entry(entity_id)
1103            .or_insert_with(|| ProjectState {
1104                context: {
1105                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
1106                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
1107                        this.handle_excerpt_store_event(entity_id, event);
1108                    })
1109                    .detach();
1110                    related_excerpt_store
1111                },
1112                events: VecDeque::new(),
1113                last_event: None,
1114                recent_paths: VecDeque::new(),
1115                debug_tx: None,
1116                registered_buffers: HashMap::default(),
1117                current_prediction: None,
1118                cancelled_predictions: HashSet::default(),
1119                pending_predictions: ArrayVec::new(),
1120                next_pending_prediction_id: 0,
1121                last_edit_prediction_refresh: None,
1122                last_jump_prediction_refresh: None,
1123                license_detection_watchers: HashMap::default(),
1124                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
1125                _subscriptions: [
1126                    cx.subscribe(&project, Self::handle_project_event),
1127                    cx.observe_release(&project, move |this, _, cx| {
1128                        this.projects.remove(&entity_id);
1129                        cx.notify();
1130                    }),
1131                ],
1132                copilot: None,
1133            })
1134    }
1135
1136    pub fn remove_project(&mut self, project: &Entity<Project>) {
1137        self.projects.remove(&project.entity_id());
1138    }
1139
1140    fn handle_excerpt_store_event(
1141        &mut self,
1142        project_entity_id: EntityId,
1143        event: &RelatedExcerptStoreEvent,
1144    ) {
1145        if let Some(project_state) = self.projects.get(&project_entity_id) {
1146            if let Some(debug_tx) = project_state.debug_tx.clone() {
1147                match event {
1148                    RelatedExcerptStoreEvent::StartedRefresh => {
1149                        debug_tx
1150                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
1151                                ContextRetrievalStartedDebugEvent {
1152                                    project_entity_id: project_entity_id,
1153                                    timestamp: Instant::now(),
1154                                    search_prompt: String::new(),
1155                                },
1156                            ))
1157                            .ok();
1158                    }
1159                    RelatedExcerptStoreEvent::FinishedRefresh {
1160                        cache_hit_count,
1161                        cache_miss_count,
1162                        mean_definition_latency,
1163                        max_definition_latency,
1164                    } => {
1165                        debug_tx
1166                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
1167                                ContextRetrievalFinishedDebugEvent {
1168                                    project_entity_id: project_entity_id,
1169                                    timestamp: Instant::now(),
1170                                    metadata: vec![
1171                                        (
1172                                            "Cache Hits",
1173                                            format!(
1174                                                "{}/{}",
1175                                                cache_hit_count,
1176                                                cache_hit_count + cache_miss_count
1177                                            )
1178                                            .into(),
1179                                        ),
1180                                        (
1181                                            "Max LSP Time",
1182                                            format!("{} ms", max_definition_latency.as_millis())
1183                                                .into(),
1184                                        ),
1185                                        (
1186                                            "Mean LSP Time",
1187                                            format!("{} ms", mean_definition_latency.as_millis())
1188                                                .into(),
1189                                        ),
1190                                    ],
1191                                },
1192                            ))
1193                            .ok();
1194                    }
1195                }
1196            }
1197        }
1198    }
1199
1200    pub fn debug_info(
1201        &mut self,
1202        project: &Entity<Project>,
1203        cx: &mut Context<Self>,
1204    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1205        let project_state = self.get_or_init_project(project, cx);
1206        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1207        project_state.debug_tx = Some(debug_watch_tx);
1208        debug_watch_rx
1209    }
1210
1211    fn handle_project_event(
1212        &mut self,
1213        project: Entity<Project>,
1214        event: &project::Event,
1215        cx: &mut Context<Self>,
1216    ) {
1217        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1218            return;
1219        }
1220        // TODO [zeta2] init with recent paths
1221        match event {
1222            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1223                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1224                    return;
1225                };
1226                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1227                if let Some(path) = path {
1228                    if let Some(ix) = project_state
1229                        .recent_paths
1230                        .iter()
1231                        .position(|probe| probe == &path)
1232                    {
1233                        project_state.recent_paths.remove(ix);
1234                    }
1235                    project_state.recent_paths.push_front(path);
1236                }
1237            }
1238            project::Event::DiagnosticsUpdated { .. } => {
1239                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1240                    self.refresh_prediction_from_diagnostics(
1241                        project,
1242                        DiagnosticSearchScope::Global,
1243                        cx,
1244                    );
1245                }
1246            }
1247            _ => (),
1248        }
1249    }
1250
1251    fn register_buffer_impl<'a>(
1252        project_state: &'a mut ProjectState,
1253        buffer: &Entity<Buffer>,
1254        project: &Entity<Project>,
1255        cx: &mut Context<Self>,
1256    ) -> &'a mut RegisteredBuffer {
1257        let buffer_id = buffer.entity_id();
1258
1259        if let Some(file) = buffer.read(cx).file() {
1260            let worktree_id = file.worktree_id(cx);
1261            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
1262                project_state
1263                    .license_detection_watchers
1264                    .entry(worktree_id)
1265                    .or_insert_with(|| {
1266                        let project_entity_id = project.entity_id();
1267                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
1268                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
1269                            else {
1270                                return;
1271                            };
1272                            project_state
1273                                .license_detection_watchers
1274                                .remove(&worktree_id);
1275                        })
1276                        .detach();
1277                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
1278                    });
1279            }
1280        }
1281
1282        match project_state.registered_buffers.entry(buffer_id) {
1283            hash_map::Entry::Occupied(entry) => entry.into_mut(),
1284            hash_map::Entry::Vacant(entry) => {
1285                let buf = buffer.read(cx);
1286                let snapshot = buf.text_snapshot();
1287                let file = buf.file().cloned();
1288                let project_entity_id = project.entity_id();
1289                entry.insert(RegisteredBuffer {
1290                    snapshot,
1291                    file,
1292                    last_position: None,
1293                    pending_predictions: Vec::new(),
1294                    _subscriptions: [
1295                        cx.subscribe(buffer, {
1296                            let project = project.downgrade();
1297                            move |this, buffer, event, cx| {
1298                                if let language::BufferEvent::Edited { is_local } = event
1299                                    && let Some(project) = project.upgrade()
1300                                {
1301                                    this.report_changes_for_buffer(
1302                                        &buffer, &project, false, *is_local, cx,
1303                                    );
1304                                }
1305                            }
1306                        }),
1307                        cx.observe_release(buffer, move |this, _buffer, _cx| {
1308                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
1309                            else {
1310                                return;
1311                            };
1312                            project_state.registered_buffers.remove(&buffer_id);
1313                        }),
1314                    ],
1315                })
1316            }
1317        }
1318    }
1319
1320    fn report_changes_for_buffer(
1321        &mut self,
1322        buffer: &Entity<Buffer>,
1323        project: &Entity<Project>,
1324        is_predicted: bool,
1325        is_local: bool,
1326        cx: &mut Context<Self>,
1327    ) {
1328        let project_state = self.get_or_init_project(project, cx);
1329        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
1330
1331        let buf = buffer.read(cx);
1332        let new_file = buf.file().cloned();
1333        let new_snapshot = buf.text_snapshot();
1334        if new_snapshot.version == registered_buffer.snapshot.version {
1335            return;
1336        }
1337        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
1338        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1339        let mut num_edits = 0usize;
1340        let mut total_deleted = 0usize;
1341        let mut total_inserted = 0usize;
1342        let mut edit_range: Option<Range<Anchor>> = None;
1343        let mut last_offset: Option<usize> = None;
1344        let now = cx.background_executor().now();
1345
1346        for (edit, anchor_range) in
1347            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
1348        {
1349            num_edits += 1;
1350            total_deleted += edit.old.len();
1351            total_inserted += edit.new.len();
1352            edit_range = Some(match edit_range {
1353                None => anchor_range,
1354                Some(acc) => acc.start..anchor_range.end,
1355            });
1356            last_offset = Some(edit.new.end);
1357        }
1358
1359        let Some(edit_range) = edit_range else {
1360            return;
1361        };
1362
1363        for pending_prediction in &mut registered_buffer.pending_predictions {
1364            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
1365                pending_prediction.last_edit_at = now;
1366            }
1367        }
1368
1369        let include_in_history = is_local
1370            || collaborator_edit_overlaps_locality_region(
1371                project_state,
1372                project,
1373                buffer,
1374                &buf.snapshot(),
1375                &edit_range,
1376                cx,
1377            );
1378
1379        if is_local {
1380            let action_type = match (total_deleted, total_inserted, num_edits) {
1381                (0, ins, n) if ins == n => UserActionType::InsertChar,
1382                (0, _, _) => UserActionType::InsertSelection,
1383                (del, 0, n) if del == n => UserActionType::DeleteChar,
1384                (_, 0, _) => UserActionType::DeleteSelection,
1385                (_, ins, n) if ins == n => UserActionType::InsertChar,
1386                (_, _, _) => UserActionType::InsertSelection,
1387            };
1388
1389            if let Some(offset) = last_offset {
1390                let point = new_snapshot.offset_to_point(offset);
1391                let timestamp_epoch_ms = SystemTime::now()
1392                    .duration_since(UNIX_EPOCH)
1393                    .map(|d| d.as_millis() as u64)
1394                    .unwrap_or(0);
1395                project_state.record_user_action(UserActionRecord {
1396                    action_type,
1397                    buffer_id: buffer.entity_id(),
1398                    line_number: point.row,
1399                    offset,
1400                    timestamp_epoch_ms,
1401                });
1402            }
1403        }
1404
1405        if !include_in_history {
1406            return;
1407        }
1408
1409        let events = &mut project_state.events;
1410
1411        if let Some(last_event) = project_state.last_event.as_mut() {
1412            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
1413                == last_event.new_snapshot.remote_id()
1414                && old_snapshot.version == last_event.new_snapshot.version;
1415
1416            let prediction_source_changed = is_predicted != last_event.predicted;
1417
1418            let should_coalesce = is_next_snapshot_of_same_buffer
1419                && !prediction_source_changed
1420                && lines_between_ranges(
1421                    &edit_range.to_point(&new_snapshot),
1422                    &last_event.latest_edit_range.to_point(&new_snapshot),
1423                ) <= CHANGE_GROUPING_LINE_SPAN;
1424
1425            if should_coalesce {
1426                let pause_elapsed = last_event
1427                    .last_edit_time
1428                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
1429                    .unwrap_or(false);
1430                if pause_elapsed {
1431                    last_event.snapshot_after_last_editing_pause =
1432                        Some(last_event.new_snapshot.clone());
1433                    last_event.total_edit_range_at_last_pause_boundary =
1434                        Some(last_event.total_edit_range.clone());
1435                }
1436
1437                last_event.latest_edit_range = edit_range.clone();
1438                last_event.total_edit_range =
1439                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
1440                last_event.new_snapshot = new_snapshot;
1441                last_event.last_edit_time = Some(now);
1442                return;
1443            }
1444        }
1445
1446        if let Some(event) = project_state.last_event.take() {
1447            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
1448                if events.len() + 1 >= EVENT_COUNT_MAX {
1449                    events.pop_front();
1450                }
1451                events.push_back(event);
1452            }
1453        }
1454
1455        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);
1456
1457        project_state.last_event = Some(LastEvent {
1458            old_file,
1459            new_file,
1460            old_snapshot,
1461            new_snapshot,
1462            latest_edit_range: edit_range.clone(),
1463            total_edit_range: edit_range,
1464            total_edit_range_at_last_pause_boundary: None,
1465            predicted: is_predicted,
1466            snapshot_after_last_editing_pause: None,
1467            last_edit_time: Some(now),
1468        });
1469    }
1470
1471    fn prediction_at(
1472        &mut self,
1473        buffer: &Entity<Buffer>,
1474        position: Option<language::Anchor>,
1475        project: &Entity<Project>,
1476        cx: &App,
1477    ) -> Option<BufferEditPrediction<'_>> {
1478        let project_state = self.projects.get_mut(&project.entity_id())?;
1479        if let Some(position) = position
1480            && let Some(buffer) = project_state
1481                .registered_buffers
1482                .get_mut(&buffer.entity_id())
1483        {
1484            buffer.last_position = Some(position);
1485        }
1486
1487        let CurrentEditPrediction {
1488            requested_by,
1489            prediction,
1490            ..
1491        } = project_state.current_prediction.as_ref()?;
1492
1493        if prediction.targets_buffer(buffer.read(cx)) {
1494            Some(BufferEditPrediction::Local { prediction })
1495        } else {
1496            let show_jump = match requested_by {
1497                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1498                    requested_by_buffer_id == &buffer.entity_id()
1499                }
1500                PredictionRequestedBy::DiagnosticsUpdate => true,
1501            };
1502
1503            if show_jump {
1504                Some(BufferEditPrediction::Jump { prediction })
1505            } else {
1506                None
1507            }
1508        }
1509    }
1510
1511    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1512        let Some(current_prediction) = self
1513            .projects
1514            .get_mut(&project.entity_id())
1515            .and_then(|project_state| project_state.current_prediction.take())
1516        else {
1517            return;
1518        };
1519
1520        self.report_changes_for_buffer(
1521            &current_prediction.prediction.buffer,
1522            project,
1523            true,
1524            true,
1525            cx,
1526        );
1527
1528        // can't hold &mut project_state ref across report_changes_for_buffer_call
1529        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1530            return;
1531        };
1532
1533        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1534            project_state.cancel_pending_prediction(pending_prediction, cx);
1535        }
1536
1537        match self.edit_prediction_model {
1538            EditPredictionModel::Sweep => {
1539                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1540            }
1541            EditPredictionModel::Mercury => {
1542                mercury::edit_prediction_accepted(
1543                    current_prediction.prediction.id,
1544                    self.client.http_client(),
1545                    cx,
1546                );
1547            }
1548            EditPredictionModel::Zeta => {
1549                let is_cloud = !matches!(
1550                    all_language_settings(None, cx).edit_predictions.provider,
1551                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1552                );
1553                if is_cloud {
1554                    zeta::edit_prediction_accepted(self, current_prediction, cx)
1555                }
1556            }
1557            EditPredictionModel::Fim { .. } => {}
1558        }
1559    }
1560
1561    async fn handle_rejected_predictions(
1562        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
1563        client: Arc<Client>,
1564        llm_token: LlmApiToken,
1565        app_version: Version,
1566        background_executor: BackgroundExecutor,
1567    ) {
1568        let mut rx = std::pin::pin!(rx.peekable());
1569        let mut batched = Vec::new();
1570
1571        while let Some(EditPredictionRejectionPayload {
1572            rejection,
1573            organization_id,
1574        }) = rx.next().await
1575        {
1576            batched.push(rejection);
1577
1578            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1579                select_biased! {
1580                    next = rx.as_mut().peek().fuse() => {
1581                        if next.is_some() {
1582                            continue;
1583                        }
1584                    }
1585                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1586                }
1587            }
1588
1589            let url = client
1590                .http_client()
1591                .build_zed_llm_url("/predict_edits/reject", &[])
1592                .unwrap();
1593
1594            let flush_count = batched
1595                .len()
1596                // in case items have accumulated after failure
1597                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1598            let start = batched.len() - flush_count;
1599
1600            let body = RejectEditPredictionsBodyRef {
1601                rejections: &batched[start..],
1602            };
1603
1604            let result = Self::send_api_request::<()>(
1605                |builder| {
1606                    let req = builder
1607                        .uri(url.as_ref())
1608                        .body(serde_json::to_string(&body)?.into());
1609                    anyhow::Ok(req?)
1610                },
1611                client.clone(),
1612                llm_token.clone(),
1613                organization_id,
1614                app_version.clone(),
1615                true,
1616            )
1617            .await;
1618
1619            if result.log_err().is_some() {
1620                batched.drain(start..);
1621            }
1622        }
1623    }
1624
1625    async fn run_settled_predictions_worker(
1626        this: WeakEntity<Self>,
1627        mut rx: UnboundedReceiver<Instant>,
1628        cx: &mut AsyncApp,
1629    ) {
1630        let mut next_wake_time: Option<Instant> = None;
1631        loop {
1632            let now = cx.background_executor().now();
1633            if let Some(wake_time) = next_wake_time.take() {
1634                cx.background_executor()
1635                    .timer(wake_time.duration_since(now))
1636                    .await;
1637            } else {
1638                let Some(new_enqueue_time) = rx.next().await else {
1639                    break;
1640                };
1641                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1642                while rx.next().now_or_never().flatten().is_some() {}
1643                continue;
1644            }
1645
1646            let Some(this) = this.upgrade() else {
1647                break;
1648            };
1649
1650            let now = cx.background_executor().now();
1651
1652            let mut oldest_edited_at = None;
1653
1654            this.update(cx, |this, _| {
1655                for (_, project_state) in this.projects.iter_mut() {
1656                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
1657                        registered_buffer
1658                            .pending_predictions
1659                            .retain_mut(|pending_prediction| {
1660                                let age =
1661                                    now.saturating_duration_since(pending_prediction.enqueued_at);
1662                                if age >= EDIT_PREDICTION_SETTLED_TTL {
1663                                    return false;
1664                                }
1665
1666                                let quiet_for =
1667                                    now.saturating_duration_since(pending_prediction.last_edit_at);
1668                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
1669                                    let settled_editable_region = registered_buffer
1670                                        .snapshot
1671                                        .text_for_range(
1672                                            pending_prediction.editable_anchor_range.clone(),
1673                                        )
1674                                        .collect::<String>();
1675
1676                                    #[cfg(test)]
1677                                    if let Some(callback) = &this.settled_event_callback {
1678                                        callback(
1679                                            pending_prediction.request_id.clone(),
1680                                            settled_editable_region.clone(),
1681                                        );
1682                                    }
1683
1684                                    telemetry::event!(
1685                                        EDIT_PREDICTION_SETTLED_EVENT,
1686                                        request_id = pending_prediction.request_id.0.clone(),
1687                                        settled_editable_region,
1688                                        example = pending_prediction.example.take(),
1689                                    );
1690
1691                                    return false;
1692                                }
1693
1694                                if oldest_edited_at
1695                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
1696                                {
1697                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
1698                                }
1699
1700                                true
1701                            });
1702                    }
1703                }
1704            });
1705
1706            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1707        }
1708    }
1709
1710    pub(crate) fn enqueue_settled_prediction(
1711        &mut self,
1712        request_id: EditPredictionId,
1713        project: &Entity<Project>,
1714        edited_buffer: &Entity<Buffer>,
1715        edited_buffer_snapshot: &BufferSnapshot,
1716        editable_offset_range: Range<usize>,
1717        example: Option<ExampleSpec>,
1718        cx: &mut Context<Self>,
1719    ) {
1720        let this = &mut *self;
1721        let project_state = this.get_or_init_project(project, cx);
1722        if let Some(buffer) = project_state
1723            .registered_buffers
1724            .get_mut(&edited_buffer.entity_id())
1725        {
1726            let now = cx.background_executor().now();
1727            buffer.pending_predictions.push(PendingSettledPrediction {
1728                request_id: request_id,
1729                editable_anchor_range: edited_buffer_snapshot
1730                    .anchor_range_around(editable_offset_range),
1731                example,
1732                enqueued_at: now,
1733                last_edit_at: now,
1734            });
1735            this.settled_predictions_tx.unbounded_send(now).ok();
1736        }
1737    }
1738
1739    fn reject_current_prediction(
1740        &mut self,
1741        reason: EditPredictionRejectReason,
1742        project: &Entity<Project>,
1743        cx: &App,
1744    ) {
1745        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1746            project_state.pending_predictions.clear();
1747            if let Some(prediction) = project_state.current_prediction.take() {
1748                let model_version = prediction.prediction.model_version.clone();
1749                self.reject_prediction(
1750                    prediction.prediction.id,
1751                    reason,
1752                    prediction.was_shown,
1753                    model_version,
1754                    cx,
1755                );
1756            }
1757        };
1758    }
1759
1760    fn did_show_current_prediction(
1761        &mut self,
1762        project: &Entity<Project>,
1763        display_type: edit_prediction_types::SuggestionDisplayType,
1764        cx: &mut Context<Self>,
1765    ) {
1766        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1767            return;
1768        };
1769
1770        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1771            return;
1772        };
1773
1774        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1775        let previous_shown_with = current_prediction.shown_with;
1776
1777        if previous_shown_with.is_none() || !is_jump {
1778            current_prediction.shown_with = Some(display_type);
1779        }
1780
1781        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1782
1783        if is_first_non_jump_show {
1784            current_prediction.was_shown = true;
1785        }
1786
1787        let display_type_changed = previous_shown_with != Some(display_type);
1788
1789        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1790            sweep_ai::edit_prediction_shown(
1791                &self.sweep_ai,
1792                self.client.clone(),
1793                &current_prediction.prediction,
1794                display_type,
1795                cx,
1796            );
1797        }
1798
1799        if is_first_non_jump_show {
1800            self.shown_predictions
1801                .push_front(current_prediction.prediction.clone());
1802            if self.shown_predictions.len() > 50 {
1803                let completion = self.shown_predictions.pop_back().unwrap();
1804                self.rated_predictions.remove(&completion.id);
1805            }
1806        }
1807    }
1808
1809    fn reject_prediction(
1810        &mut self,
1811        prediction_id: EditPredictionId,
1812        reason: EditPredictionRejectReason,
1813        was_shown: bool,
1814        model_version: Option<String>,
1815        cx: &App,
1816    ) {
1817        match self.edit_prediction_model {
1818            EditPredictionModel::Zeta => {
1819                let is_cloud = !matches!(
1820                    all_language_settings(None, cx).edit_predictions.provider,
1821                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1822                );
1823
1824                if is_cloud {
1825                    let organization_id = self
1826                        .user_store
1827                        .read(cx)
1828                        .current_organization()
1829                        .map(|organization| organization.id.clone());
1830
1831                    self.reject_predictions_tx
1832                        .unbounded_send(EditPredictionRejectionPayload {
1833                            rejection: EditPredictionRejection {
1834                                request_id: prediction_id.to_string(),
1835                                reason,
1836                                was_shown,
1837                                model_version,
1838                            },
1839                            organization_id,
1840                        })
1841                        .log_err();
1842                }
1843            }
1844            EditPredictionModel::Mercury => {
1845                mercury::edit_prediction_rejected(
1846                    prediction_id,
1847                    was_shown,
1848                    reason,
1849                    self.client.http_client(),
1850                    cx,
1851                );
1852            }
1853            EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1854        }
1855    }
1856
1857    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1858        self.projects
1859            .get(&project.entity_id())
1860            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1861    }
1862
1863    pub fn refresh_prediction_from_buffer(
1864        &mut self,
1865        project: Entity<Project>,
1866        buffer: Entity<Buffer>,
1867        position: language::Anchor,
1868        cx: &mut Context<Self>,
1869    ) {
1870        self.queue_prediction_refresh(
1871            project.clone(),
1872            PredictEditsRequestTrigger::Other,
1873            buffer.entity_id(),
1874            cx,
1875            move |this, cx| {
1876                let Some(request_task) = this
1877                    .update(cx, |this, cx| {
1878                        this.request_prediction(
1879                            &project,
1880                            &buffer,
1881                            position,
1882                            PredictEditsRequestTrigger::Other,
1883                            cx,
1884                        )
1885                    })
1886                    .log_err()
1887                else {
1888                    return Task::ready(anyhow::Ok(None));
1889                };
1890
1891                cx.spawn(async move |_cx| {
1892                    request_task.await.map(|prediction_result| {
1893                        prediction_result.map(|prediction_result| {
1894                            (
1895                                prediction_result,
1896                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1897                            )
1898                        })
1899                    })
1900                })
1901            },
1902        )
1903    }
1904
1905    pub fn refresh_prediction_from_diagnostics(
1906        &mut self,
1907        project: Entity<Project>,
1908        scope: DiagnosticSearchScope,
1909        cx: &mut Context<Self>,
1910    ) {
1911        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1912            return;
1913        }
1914
1915        if currently_following(&project, cx) {
1916            return;
1917        }
1918
1919        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1920            return;
1921        };
1922
1923        // Prefer predictions from buffer
1924        if project_state.current_prediction.is_some() {
1925            log::debug!(
1926                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
1927            );
1928            return;
1929        }
1930
1931        self.queue_prediction_refresh(
1932            project.clone(),
1933            PredictEditsRequestTrigger::Diagnostics,
1934            project.entity_id(),
1935            cx,
1936            move |this, cx| {
1937                let Some((active_buffer, snapshot, cursor_point)) = this
1938                    .read_with(cx, |this, cx| {
1939                        let project_state = this.projects.get(&project.entity_id())?;
1940                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
1941                        let snapshot = buffer.read(cx).snapshot();
1942
1943                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
1944                            return None;
1945                        }
1946
1947                        let cursor_point = position
1948                            .map(|pos| pos.to_point(&snapshot))
1949                            .unwrap_or_default();
1950
1951                        Some((buffer, snapshot, cursor_point))
1952                    })
1953                    .log_err()
1954                    .flatten()
1955                else {
1956                    return Task::ready(anyhow::Ok(None));
1957                };
1958
1959                cx.spawn(async move |cx| {
1960                    let diagnostic_search_range = match scope {
1961                        DiagnosticSearchScope::Local => {
1962                            let diagnostic_search_start =
1963                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1964                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1965                            Point::new(diagnostic_search_start, 0)
1966                                ..Point::new(diagnostic_search_end, 0)
1967                        }
1968                        DiagnosticSearchScope::Global => Default::default(),
1969                    };
1970
1971                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1972                        active_buffer,
1973                        &snapshot,
1974                        diagnostic_search_range,
1975                        cursor_point,
1976                        &project,
1977                        cx,
1978                    )
1979                    .await?
1980                    else {
1981                        return anyhow::Ok(None);
1982                    };
1983
1984                    let Some(prediction_result) = this
1985                        .update(cx, |this, cx| {
1986                            this.request_prediction(
1987                                &project,
1988                                &jump_buffer,
1989                                jump_position,
1990                                PredictEditsRequestTrigger::Diagnostics,
1991                                cx,
1992                            )
1993                        })?
1994                        .await?
1995                    else {
1996                        return anyhow::Ok(None);
1997                    };
1998
1999                    this.update(cx, |this, cx| {
2000                        Some((
2001                            if this
2002                                .get_or_init_project(&project, cx)
2003                                .current_prediction
2004                                .is_none()
2005                            {
2006                                prediction_result
2007                            } else {
2008                                EditPredictionResult {
2009                                    id: prediction_result.id,
2010                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
2011                                }
2012                            },
2013                            PredictionRequestedBy::DiagnosticsUpdate,
2014                        ))
2015                    })
2016                })
2017            },
2018        );
2019    }
2020
2021    fn predictions_enabled_at(
2022        snapshot: &BufferSnapshot,
2023        position: Option<language::Anchor>,
2024        cx: &App,
2025    ) -> bool {
2026        let file = snapshot.file();
2027        let all_settings = all_language_settings(file, cx);
2028        if !all_settings.show_edit_predictions(snapshot.language(), cx)
2029            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
2030        {
2031            return false;
2032        }
2033
2034        if let Some(last_position) = position {
2035            let settings = snapshot.settings_at(last_position, cx);
2036
2037            if !settings.edit_predictions_disabled_in.is_empty()
2038                && let Some(scope) = snapshot.language_scope_at(last_position)
2039                && let Some(scope_name) = scope.override_name()
2040                && settings
2041                    .edit_predictions_disabled_in
2042                    .iter()
2043                    .any(|s| s == scope_name)
2044            {
2045                return false;
2046            }
2047        }
2048
2049        true
2050    }
2051
    /// Minimum spacing enforced by `queue_prediction_refresh` between two
    /// refreshes requested for the same throttle entity.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2053}
2054
2055fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2056    let Some(app_state) = AppState::try_global(cx).and_then(|app_state| app_state.upgrade()) else {
2057        return false;
2058    };
2059
2060    app_state
2061        .workspace_store
2062        .read(cx)
2063        .workspaces()
2064        .filter_map(|workspace| workspace.upgrade())
2065        .any(|workspace| {
2066            workspace.read(cx).project().entity_id() == project.entity_id()
2067                && workspace
2068                    .read(cx)
2069                    .leader_for_pane(workspace.read(cx).active_pane())
2070                    .is_some()
2071        })
2072}
2073
2074fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2075    match provider {
2076        EditPredictionProvider::Zed
2077        | EditPredictionProvider::Sweep
2078        | EditPredictionProvider::Mercury
2079        | EditPredictionProvider::Ollama
2080        | EditPredictionProvider::OpenAiCompatibleApi
2081        | EditPredictionProvider::Experimental(_) => true,
2082        EditPredictionProvider::None
2083        | EditPredictionProvider::Copilot
2084        | EditPredictionProvider::Codestral => false,
2085    }
2086}
2087
2088impl EditPredictionStore {
2089    fn queue_prediction_refresh(
2090        &mut self,
2091        project: Entity<Project>,
2092        request_trigger: PredictEditsRequestTrigger,
2093        throttle_entity: EntityId,
2094        cx: &mut Context<Self>,
2095        do_refresh: impl FnOnce(
2096            WeakEntity<Self>,
2097            &mut AsyncApp,
2098        )
2099            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
2100        + 'static,
2101    ) {
2102        fn select_throttle(
2103            project_state: &mut ProjectState,
2104            request_trigger: PredictEditsRequestTrigger,
2105        ) -> &mut Option<(EntityId, Instant)> {
2106            match request_trigger {
2107                PredictEditsRequestTrigger::Diagnostics => {
2108                    &mut project_state.last_jump_prediction_refresh
2109                }
2110                _ => &mut project_state.last_edit_prediction_refresh,
2111            }
2112        }
2113
2114        let (needs_acceptance_tracking, max_pending_predictions) =
2115            match all_language_settings(None, cx).edit_predictions.provider {
2116                EditPredictionProvider::Zed
2117                | EditPredictionProvider::Sweep
2118                | EditPredictionProvider::Mercury
2119                | EditPredictionProvider::Experimental(_) => (true, 2),
2120                EditPredictionProvider::Ollama => (false, 1),
2121                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
2122                EditPredictionProvider::None
2123                | EditPredictionProvider::Copilot
2124                | EditPredictionProvider::Codestral => {
2125                    log::error!("queue_prediction_refresh called with non-store provider");
2126                    return;
2127                }
2128            };
2129
2130        let drop_on_cancel = !needs_acceptance_tracking;
2131        let throttle_timeout = Self::THROTTLE_TIMEOUT;
2132        let project_state = self.get_or_init_project(&project, cx);
2133        let pending_prediction_id = project_state.next_pending_prediction_id;
2134        project_state.next_pending_prediction_id += 1;
2135        let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2136
2137        let task = cx.spawn(async move |this, cx| {
2138            let throttle_wait = this
2139                .update(cx, |this, cx| {
2140                    let project_state = this.get_or_init_project(&project, cx);
2141                    let throttle = *select_throttle(project_state, request_trigger);
2142
2143                    throttle.and_then(|(last_entity, last_timestamp)| {
2144                        if throttle_entity != last_entity {
2145                            return None;
2146                        }
2147                        (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
2148                    })
2149                })
2150                .ok()
2151                .flatten();
2152
2153            if let Some(timeout) = throttle_wait {
2154                cx.background_executor().timer(timeout).await;
2155            }
2156
2157            // If this task was cancelled before the throttle timeout expired,
2158            // do not perform a request. Also skip if another task already
2159            // proceeded since we were enqueued (duplicate).
2160            let mut is_cancelled = true;
2161            this.update(cx, |this, cx| {
2162                let project_state = this.get_or_init_project(&project, cx);
2163                let was_cancelled = project_state
2164                    .cancelled_predictions
2165                    .remove(&pending_prediction_id);
2166                if was_cancelled {
2167                    return;
2168                }
2169
2170                // Another request has been already sent since this was enqueued
2171                if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2172                    return;
2173                }
2174
2175                let new_refresh = (throttle_entity, Instant::now());
2176                *select_throttle(project_state, request_trigger) = Some(new_refresh);
2177                is_cancelled = false;
2178            })
2179            .ok();
2180            if is_cancelled {
2181                return None;
2182            }
2183
2184            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2185            let new_prediction_id = new_prediction_result
2186                .as_ref()
2187                .map(|(prediction, _)| prediction.id.clone());
2188
2189            // When a prediction completes, remove it from the pending list, and cancel
2190            // any pending predictions that were enqueued before it.
2191            this.update(cx, |this, cx| {
2192                let project_state = this.get_or_init_project(&project, cx);
2193
2194                let is_cancelled = project_state
2195                    .cancelled_predictions
2196                    .remove(&pending_prediction_id);
2197
2198                let new_current_prediction = if !is_cancelled
2199                    && let Some((prediction_result, requested_by)) = new_prediction_result
2200                {
2201                    match prediction_result.prediction {
2202                        Ok(prediction) => {
2203                            let new_prediction = CurrentEditPrediction {
2204                                requested_by,
2205                                prediction,
2206                                was_shown: false,
2207                                shown_with: None,
2208                            };
2209
2210                            if let Some(current_prediction) =
2211                                project_state.current_prediction.as_ref()
2212                            {
2213                                if new_prediction.should_replace_prediction(&current_prediction, cx)
2214                                {
2215                                    this.reject_current_prediction(
2216                                        EditPredictionRejectReason::Replaced,
2217                                        &project,
2218                                        cx,
2219                                    );
2220
2221                                    Some(new_prediction)
2222                                } else {
2223                                    this.reject_prediction(
2224                                        new_prediction.prediction.id,
2225                                        EditPredictionRejectReason::CurrentPreferred,
2226                                        false,
2227                                        new_prediction.prediction.model_version,
2228                                        cx,
2229                                    );
2230                                    None
2231                                }
2232                            } else {
2233                                Some(new_prediction)
2234                            }
2235                        }
2236                        Err(reject_reason) => {
2237                            this.reject_prediction(
2238                                prediction_result.id,
2239                                reject_reason,
2240                                false,
2241                                None,
2242                                cx,
2243                            );
2244                            None
2245                        }
2246                    }
2247                } else {
2248                    None
2249                };
2250
2251                let project_state = this.get_or_init_project(&project, cx);
2252
2253                if let Some(new_prediction) = new_current_prediction {
2254                    project_state.current_prediction = Some(new_prediction);
2255                }
2256
2257                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2258                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2259                    if pending_prediction.id == pending_prediction_id {
2260                        pending_predictions.remove(ix);
2261                        for pending_prediction in pending_predictions.drain(0..ix) {
2262                            project_state.cancel_pending_prediction(pending_prediction, cx)
2263                        }
2264                        break;
2265                    }
2266                }
2267                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2268                cx.notify();
2269            })
2270            .ok();
2271
2272            new_prediction_id
2273        });
2274
2275        if project_state.pending_predictions.len() < max_pending_predictions {
2276            project_state.pending_predictions.push(PendingPrediction {
2277                id: pending_prediction_id,
2278                task,
2279                drop_on_cancel,
2280            });
2281        } else {
2282            let pending_prediction = project_state.pending_predictions.pop().unwrap();
2283            project_state.pending_predictions.push(PendingPrediction {
2284                id: pending_prediction_id,
2285                task,
2286                drop_on_cancel,
2287            });
2288            project_state.cancel_pending_prediction(pending_prediction, cx);
2289        }
2290    }
2291
2292    pub fn request_prediction(
2293        &mut self,
2294        project: &Entity<Project>,
2295        active_buffer: &Entity<Buffer>,
2296        position: language::Anchor,
2297        trigger: PredictEditsRequestTrigger,
2298        cx: &mut Context<Self>,
2299    ) -> Task<Result<Option<EditPredictionResult>>> {
2300        self.request_prediction_internal(
2301            project.clone(),
2302            active_buffer.clone(),
2303            position,
2304            trigger,
2305            cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2306            cx,
2307        )
2308    }
2309
2310    fn request_prediction_internal(
2311        &mut self,
2312        project: Entity<Project>,
2313        active_buffer: Entity<Buffer>,
2314        position: language::Anchor,
2315        trigger: PredictEditsRequestTrigger,
2316        allow_jump: bool,
2317        cx: &mut Context<Self>,
2318    ) -> Task<Result<Option<EditPredictionResult>>> {
2319        self.get_or_init_project(&project, cx);
2320        let project_state = self.projects.get(&project.entity_id()).unwrap();
2321        let stored_events = project_state.events(cx);
2322        let has_events = !stored_events.is_empty();
2323        let events: Vec<Arc<zeta_prompt::Event>> =
2324            stored_events.iter().map(|e| e.event.clone()).collect();
2325        let debug_tx = project_state.debug_tx.clone();
2326
2327        let snapshot = active_buffer.read(cx).snapshot();
2328        let cursor_point = position.to_point(&snapshot);
2329        let current_offset = position.to_offset(&snapshot);
2330
2331        let mut user_actions: Vec<UserActionRecord> =
2332            project_state.user_actions.iter().cloned().collect();
2333
2334        if let Some(last_action) = user_actions.last() {
2335            if last_action.buffer_id == active_buffer.entity_id()
2336                && current_offset != last_action.offset
2337            {
2338                let timestamp_epoch_ms = SystemTime::now()
2339                    .duration_since(UNIX_EPOCH)
2340                    .map(|d| d.as_millis() as u64)
2341                    .unwrap_or(0);
2342                user_actions.push(UserActionRecord {
2343                    action_type: UserActionType::CursorMovement,
2344                    buffer_id: active_buffer.entity_id(),
2345                    line_number: cursor_point.row,
2346                    offset: current_offset,
2347                    timestamp_epoch_ms,
2348                });
2349            }
2350        }
2351        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2352        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2353        let diagnostic_search_range =
2354            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2355
2356        let related_files = self.context_for_project(&project, cx);
2357
2358        let is_open_source = snapshot
2359            .file()
2360            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2361            && events.iter().all(|event| event.in_open_source_repo())
2362            && related_files.iter().all(|file| file.in_open_source_repo);
2363
2364        let can_collect_data = !cfg!(test)
2365            && is_open_source
2366            && self.is_data_collection_enabled(cx)
2367            && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2368
2369        let recent_paths = project_state.recent_paths.clone();
2370
2371        let inputs = EditPredictionModelInput {
2372            project: project.clone(),
2373            buffer: active_buffer,
2374            snapshot,
2375            position,
2376            events,
2377            related_files,
2378            recent_paths,
2379            trigger,
2380            diagnostic_search_range: diagnostic_search_range,
2381            debug_tx,
2382            user_actions,
2383            can_collect_data,
2384            is_open_source,
2385        };
2386
2387        let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2388
2389        let task = match self.edit_prediction_model {
2390            EditPredictionModel::Zeta => {
2391                zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2392            }
2393            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2394            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2395            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2396        };
2397
2398        cx.spawn(async move |this, cx| {
2399            let prediction = task.await?;
2400
2401            // Only fall back to diagnostics-based prediction if we got a
2402            // the model had nothing to suggest for the buffer
2403            if prediction.is_none()
2404                && allow_jump
2405                && has_events
2406                && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2407            {
2408                this.update(cx, |this, cx| {
2409                    this.refresh_prediction_from_diagnostics(
2410                        project,
2411                        DiagnosticSearchScope::Local,
2412                        cx,
2413                    );
2414                })?;
2415                return anyhow::Ok(None);
2416            }
2417
2418            Ok(prediction)
2419        })
2420    }
2421
2422    pub(crate) async fn next_diagnostic_location(
2423        active_buffer: Entity<Buffer>,
2424        active_buffer_snapshot: &BufferSnapshot,
2425        active_buffer_diagnostic_search_range: Range<Point>,
2426        active_buffer_cursor_point: Point,
2427        project: &Entity<Project>,
2428        cx: &mut AsyncApp,
2429    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2430        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2431            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2432            .flat_map(|(_, _, _, selections)| {
2433                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2434            })
2435            .collect();
2436
2437        let mut jump_location = active_buffer_snapshot
2438            .diagnostic_groups(None)
2439            .into_iter()
2440            .filter_map(|(_, group)| {
2441                let range = &group.entries[group.primary_ix]
2442                    .range
2443                    .to_point(&active_buffer_snapshot);
2444                if range.overlaps(&active_buffer_diagnostic_search_range) {
2445                    return None;
2446                }
2447                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2448                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2449                });
2450                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2451                    <= DIAGNOSTIC_LINES_RANGE;
2452                if near_collaborator && !near_local {
2453                    return None;
2454                }
2455                Some(range.start)
2456            })
2457            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2458            .map(|position| {
2459                (
2460                    active_buffer.clone(),
2461                    active_buffer_snapshot.anchor_before(position),
2462                )
2463            });
2464
2465        if jump_location.is_none() {
2466            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2467                let file = buffer.file()?;
2468
2469                Some(ProjectPath {
2470                    worktree_id: file.worktree_id(cx),
2471                    path: file.path().clone(),
2472                })
2473            });
2474
2475            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2476                project
2477                    .diagnostic_summaries(false, cx)
2478                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2479                    .map(|(path, _, _)| {
2480                        let shared_prefix = path
2481                            .path
2482                            .components()
2483                            .zip(
2484                                active_buffer_path
2485                                    .as_ref()
2486                                    .map(|p| p.path.components())
2487                                    .unwrap_or_default(),
2488                            )
2489                            .take_while(|(a, b)| a == b)
2490                            .count();
2491                        (path, shared_prefix)
2492                    })
2493                    .collect()
2494            });
2495
2496            candidates.sort_by(|a, b| b.1.cmp(&a.1));
2497
2498            for (path, _) in candidates {
2499                let candidate_buffer = project
2500                    .update(cx, |project, cx| project.open_buffer(path, cx))
2501                    .await?;
2502
2503                let (has_collaborators, diagnostic_position) =
2504                    candidate_buffer.read_with(cx, |buffer, _cx| {
2505                        let snapshot = buffer.snapshot();
2506                        let has_collaborators = snapshot
2507                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2508                            .next()
2509                            .is_some();
2510                        let position = buffer
2511                            .buffer_diagnostics(None)
2512                            .into_iter()
2513                            .min_by_key(|entry| entry.diagnostic.severity)
2514                            .map(|entry| entry.range.start);
2515                        (has_collaborators, position)
2516                    });
2517
2518                if has_collaborators {
2519                    continue;
2520                }
2521
2522                if let Some(position) = diagnostic_position {
2523                    jump_location = Some((candidate_buffer, position));
2524                    break;
2525                }
2526            }
2527        }
2528
2529        anyhow::Ok(jump_location)
2530    }
2531
2532    async fn send_raw_llm_request(
2533        request: RawCompletionRequest,
2534        client: Arc<Client>,
2535        custom_url: Option<Arc<Url>>,
2536        llm_token: LlmApiToken,
2537        organization_id: Option<OrganizationId>,
2538        app_version: Version,
2539    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2540        let url = if let Some(custom_url) = custom_url {
2541            custom_url.as_ref().clone()
2542        } else {
2543            client
2544                .http_client()
2545                .build_zed_llm_url("/predict_edits/raw", &[])?
2546        };
2547
2548        Self::send_api_request(
2549            |builder| {
2550                let req = builder
2551                    .uri(url.as_ref())
2552                    .body(serde_json::to_string(&request)?.into());
2553                Ok(req?)
2554            },
2555            client,
2556            llm_token,
2557            organization_id,
2558            app_version,
2559            true,
2560        )
2561        .await
2562    }
2563
2564    pub(crate) async fn send_v3_request(
2565        input: ZetaPromptInput,
2566        client: Arc<Client>,
2567        llm_token: LlmApiToken,
2568        organization_id: Option<OrganizationId>,
2569        app_version: Version,
2570        trigger: PredictEditsRequestTrigger,
2571    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2572        let url = client
2573            .http_client()
2574            .build_zed_llm_url("/predict_edits/v3", &[])?;
2575
2576        let request = PredictEditsV3Request { input, trigger };
2577
2578        let json_bytes = serde_json::to_vec(&request)?;
2579        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2580
2581        Self::send_api_request(
2582            |builder| {
2583                let req = builder
2584                    .uri(url.as_ref())
2585                    .header("Content-Encoding", "zstd")
2586                    .body(compressed.clone().into());
2587                Ok(req?)
2588            },
2589            client,
2590            llm_token,
2591            organization_id,
2592            app_version,
2593            true,
2594        )
2595        .await
2596    }
2597
2598    async fn send_api_request<Res>(
2599        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2600        client: Arc<Client>,
2601        llm_token: LlmApiToken,
2602        organization_id: Option<OrganizationId>,
2603        app_version: Version,
2604        require_auth: bool,
2605    ) -> Result<(Res, Option<EditPredictionUsage>)>
2606    where
2607        Res: DeserializeOwned,
2608    {
2609        let http_client = client.http_client();
2610
2611        let mut token = if require_auth {
2612            Some(llm_token.acquire(&client, organization_id.clone()).await?)
2613        } else {
2614            llm_token
2615                .acquire(&client, organization_id.clone())
2616                .await
2617                .ok()
2618        };
2619        let mut did_retry = false;
2620
2621        loop {
2622            let request_builder = http_client::Request::builder().method(Method::POST);
2623
2624            let mut request_builder = request_builder
2625                .header("Content-Type", "application/json")
2626                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2627
2628            // Only add Authorization header if we have a token
2629            if let Some(ref token_value) = token {
2630                request_builder =
2631                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2632            }
2633
2634            let request = build(request_builder)?;
2635
2636            let mut response = http_client.send(request).await?;
2637
2638            if let Some(minimum_required_version) = response
2639                .headers()
2640                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2641                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2642            {
2643                anyhow::ensure!(
2644                    app_version >= minimum_required_version,
2645                    ZedUpdateRequiredError {
2646                        minimum_version: minimum_required_version
2647                    }
2648                );
2649            }
2650
2651            if response.status().is_success() {
2652                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2653
2654                let mut body = Vec::new();
2655                response.body_mut().read_to_end(&mut body).await?;
2656                return Ok((serde_json::from_slice(&body)?, usage));
2657            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2658                did_retry = true;
2659                token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
2660            } else {
2661                let mut body = String::new();
2662                response.body_mut().read_to_string(&mut body).await?;
2663                anyhow::bail!(
2664                    "Request failed with status: {:?}\nBody: {}",
2665                    response.status(),
2666                    body
2667                );
2668            }
2669        }
2670    }
2671
2672    pub fn refresh_context(
2673        &mut self,
2674        project: &Entity<Project>,
2675        buffer: &Entity<language::Buffer>,
2676        cursor_position: language::Anchor,
2677        cx: &mut Context<Self>,
2678    ) {
2679        self.get_or_init_project(project, cx)
2680            .context
2681            .update(cx, |store, cx| {
2682                store.refresh(buffer.clone(), cursor_position, cx);
2683            });
2684    }
2685
2686    #[cfg(feature = "cli-support")]
2687    pub fn set_context_for_buffer(
2688        &mut self,
2689        project: &Entity<Project>,
2690        related_files: Vec<RelatedFile>,
2691        cx: &mut Context<Self>,
2692    ) {
2693        self.get_or_init_project(project, cx)
2694            .context
2695            .update(cx, |store, cx| {
2696                store.set_related_files(related_files, cx);
2697            });
2698    }
2699
2700    #[cfg(feature = "cli-support")]
2701    pub fn set_recent_paths_for_project(
2702        &mut self,
2703        project: &Entity<Project>,
2704        paths: impl IntoIterator<Item = project::ProjectPath>,
2705        cx: &mut Context<Self>,
2706    ) {
2707        let project_state = self.get_or_init_project(project, cx);
2708        project_state.recent_paths = paths.into_iter().collect();
2709    }
2710
2711    fn is_file_open_source(
2712        &self,
2713        project: &Entity<Project>,
2714        file: &Arc<dyn File>,
2715        cx: &App,
2716    ) -> bool {
2717        if !file.is_local() || file.is_private() {
2718            return false;
2719        }
2720        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2721            return false;
2722        };
2723        project_state
2724            .license_detection_watchers
2725            .get(&file.worktree_id(cx))
2726            .as_ref()
2727            .is_some_and(|watcher| watcher.is_project_open_source())
2728    }
2729
2730    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2731        self.data_collection_choice.is_enabled(cx)
2732    }
2733
2734    fn load_data_collection_choice() -> DataCollectionChoice {
2735        let choice = KEY_VALUE_STORE
2736            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2737            .log_err()
2738            .flatten();
2739
2740        match choice.as_deref() {
2741            Some("true") => DataCollectionChoice::Enabled,
2742            Some("false") => DataCollectionChoice::Disabled,
2743            Some(_) => {
2744                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2745                DataCollectionChoice::NotAnswered
2746            }
2747            None => DataCollectionChoice::NotAnswered,
2748        }
2749    }
2750
2751    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2752        self.data_collection_choice = self.data_collection_choice.toggle();
2753        let new_choice = self.data_collection_choice;
2754        let is_enabled = new_choice.is_enabled(cx);
2755        db::write_and_log(cx, move || {
2756            KEY_VALUE_STORE.write_kvp(
2757                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2758                is_enabled.to_string(),
2759            )
2760        });
2761    }
2762
2763    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2764        self.shown_predictions.iter()
2765    }
2766
2767    pub fn shown_completions_len(&self) -> usize {
2768        self.shown_predictions.len()
2769    }
2770
2771    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2772        self.rated_predictions.contains(id)
2773    }
2774
2775    pub fn rate_prediction(
2776        &mut self,
2777        prediction: &EditPrediction,
2778        rating: EditPredictionRating,
2779        feedback: String,
2780        cx: &mut Context<Self>,
2781    ) {
2782        let organization = self.user_store.read(cx).current_organization();
2783
2784        self.rated_predictions.insert(prediction.id.clone());
2785
2786        cx.background_spawn({
2787            let client = self.client.clone();
2788            let prediction_id = prediction.id.to_string();
2789            let inputs = serde_json::to_value(&prediction.inputs);
2790            let output = prediction
2791                .edit_preview
2792                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2793            async move {
2794                client
2795                    .cloud_client()
2796                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2797                        organization_id: organization.map(|organization| organization.id.clone()),
2798                        request_id: prediction_id,
2799                        rating: match rating {
2800                            EditPredictionRating::Positive => "positive".to_string(),
2801                            EditPredictionRating::Negative => "negative".to_string(),
2802                        },
2803                        inputs: inputs?,
2804                        output,
2805                        feedback,
2806                    })
2807                    .await?;
2808
2809                anyhow::Ok(())
2810            }
2811        })
2812        .detach_and_log_err(cx);
2813
2814        cx.notify();
2815    }
2816}
2817
2818fn collaborator_edit_overlaps_locality_region(
2819    project_state: &ProjectState,
2820    project: &Entity<Project>,
2821    buffer: &Entity<Buffer>,
2822    snapshot: &BufferSnapshot,
2823    edit_range: &Range<Anchor>,
2824    cx: &App,
2825) -> bool {
2826    let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2827        return false;
2828    };
2829
2830    if active_buffer.entity_id() != buffer.entity_id() {
2831        return false;
2832    }
2833
2834    let locality_point_range = expand_context_syntactically_then_linewise(
2835        snapshot,
2836        (position..position).to_point(snapshot),
2837        COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2838    );
2839    let locality_anchor_range = snapshot.anchor_range_around(locality_point_range);
2840
2841    edit_range.overlaps(&locality_anchor_range, snapshot)
2842}
2843
2844fn merge_trailing_events_if_needed(
2845    events: &mut VecDeque<StoredEvent>,
2846    end_snapshot: &TextBufferSnapshot,
2847    latest_snapshot: &TextBufferSnapshot,
2848    latest_edit_range: &Range<Anchor>,
2849) {
2850    if let Some(last_event) = events.back() {
2851        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
2852            return;
2853        }
2854        if !latest_snapshot
2855            .version
2856            .observed_all(&last_event.new_snapshot_version)
2857        {
2858            return;
2859        }
2860    }
2861
2862    let mut next_old_event = None;
2863    let mut mergeable_count = 0;
2864    for old_event in events.iter().rev() {
2865        if let Some(next_old_event) = next_old_event
2866            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
2867        {
2868            break;
2869        }
2870        mergeable_count += 1;
2871        next_old_event = Some(old_event);
2872    }
2873
2874    if mergeable_count <= 1 {
2875        return;
2876    }
2877
2878    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
2879    let oldest_event = events_to_merge.peek().unwrap();
2880    let oldest_snapshot = oldest_event.old_snapshot.clone();
2881    let newest_snapshot = end_snapshot;
2882    let mut merged_edit_range = oldest_event.total_edit_range.clone();
2883
2884    for event in events.range(events.len() - mergeable_count + 1..) {
2885        merged_edit_range =
2886            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
2887    }
2888
2889    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
2890        &oldest_snapshot,
2891        newest_snapshot,
2892        &merged_edit_range,
2893    ) {
2894        let merged_event = match oldest_event.event.as_ref() {
2895            zeta_prompt::Event::BufferChange {
2896                old_path,
2897                path,
2898                in_open_source_repo,
2899                ..
2900            } => StoredEvent {
2901                event: Arc::new(zeta_prompt::Event::BufferChange {
2902                    old_path: old_path.clone(),
2903                    path: path.clone(),
2904                    diff,
2905                    in_open_source_repo: *in_open_source_repo,
2906                    predicted: events_to_merge.all(|e| {
2907                        matches!(
2908                            e.event.as_ref(),
2909                            zeta_prompt::Event::BufferChange {
2910                                predicted: true,
2911                                ..
2912                            }
2913                        )
2914                    }),
2915                }),
2916                old_snapshot: oldest_snapshot.clone(),
2917                new_snapshot_version: newest_snapshot.version.clone(),
2918                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
2919                    ..newest_snapshot.anchor_before(edit_range.end),
2920            },
2921        };
2922        events.truncate(events.len() - mergeable_count);
2923        events.push_back(merged_event);
2924    }
2925}
2926
2927fn merge_anchor_ranges(
2928    left: &Range<Anchor>,
2929    right: &Range<Anchor>,
2930    snapshot: &TextBufferSnapshot,
2931) -> Range<Anchor> {
2932    let start = if left.start.cmp(&right.start, snapshot).is_le() {
2933        left.start
2934    } else {
2935        right.start
2936    };
2937    let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2938        left.end
2939    } else {
2940        right.end
2941    };
2942    start..end
2943}
2944
2945#[derive(Error, Debug)]
2946#[error(
2947    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2948)]
2949pub struct ZedUpdateRequiredError {
2950    minimum_version: Version,
2951}
2952
2953#[derive(Debug, Clone, Copy)]
2954pub enum DataCollectionChoice {
2955    NotAnswered,
2956    Enabled,
2957    Disabled,
2958}
2959
2960impl DataCollectionChoice {
2961    pub fn is_enabled(self, cx: &App) -> bool {
2962        if cx.is_staff() {
2963            return true;
2964        }
2965        match self {
2966            Self::Enabled => true,
2967            Self::NotAnswered | Self::Disabled => false,
2968        }
2969    }
2970
2971    #[must_use]
2972    pub fn toggle(&self) -> DataCollectionChoice {
2973        match self {
2974            Self::Enabled => Self::Disabled,
2975            Self::Disabled => Self::Enabled,
2976            Self::NotAnswered => Self::Enabled,
2977        }
2978    }
2979}
2980
2981impl From<bool> for DataCollectionChoice {
2982    fn from(value: bool) -> Self {
2983        match value {
2984            true => DataCollectionChoice::Enabled,
2985            false => DataCollectionChoice::Disabled,
2986        }
2987    }
2988}
2989
2990struct ZedPredictUpsell;
2991
2992impl Dismissable for ZedPredictUpsell {
2993    const KEY: &'static str = "dismissed-edit-predict-upsell";
2994
2995    fn dismissed() -> bool {
2996        // To make this backwards compatible with older versions of Zed, we
2997        // check if the user has seen the previous Edit Prediction Onboarding
2998        // before, by checking the data collection choice which was written to
2999        // the database once the user clicked on "Accept and Enable"
3000        if KEY_VALUE_STORE
3001            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
3002            .log_err()
3003            .is_some_and(|s| s.is_some())
3004        {
3005            return true;
3006        }
3007
3008        KEY_VALUE_STORE
3009            .read_kvp(Self::KEY)
3010            .log_err()
3011            .is_some_and(|s| s.is_some())
3012    }
3013}
3014
3015pub fn should_show_upsell_modal() -> bool {
3016    !ZedPredictUpsell::dismissed()
3017}
3018
3019pub fn init(cx: &mut App) {
3020    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
3021        workspace.register_action(
3022            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
3023                ZedPredictModal::toggle(
3024                    workspace,
3025                    workspace.user_store().clone(),
3026                    workspace.client().clone(),
3027                    window,
3028                    cx,
3029                )
3030            },
3031        );
3032
3033        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
3034            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
3035                settings
3036                    .project
3037                    .all_languages
3038                    .edit_predictions
3039                    .get_or_insert_default()
3040                    .provider = Some(EditPredictionProvider::None)
3041            });
3042        });
3043        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
3044            EditPredictionStore::try_global(cx).and_then(|store| {
3045                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
3046            })
3047        }
3048
3049        workspace.register_action(|workspace, _: &SignIn, window, cx| {
3050            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3051                copilot_ui::initiate_sign_in(copilot, window, cx);
3052            }
3053        });
3054        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3055            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3056                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3057            }
3058        });
3059        workspace.register_action(|workspace, _: &SignOut, window, cx| {
3060            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3061                copilot_ui::initiate_sign_out(copilot, window, cx);
3062            }
3063        });
3064    })
3065    .detach();
3066}