edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
   5use cloud_llm_client::predict_edits_v3::{
   6    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   7};
   8use cloud_llm_client::{
   9    EditPredictionRejectReason, EditPredictionRejection,
  10    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  11    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  12};
  13use collections::{HashMap, HashSet};
  14use copilot::{Copilot, Reinstall, SignIn, SignOut};
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::{
  19    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  20    channel::mpsc::{self, UnboundedReceiver},
  21    select_biased,
  22};
  23use gpui::BackgroundExecutor;
  24use gpui::http_client::Url;
  25use gpui::{
  26    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  27    http_client::{self, AsyncBody, Method},
  28    prelude::*,
  29};
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{
  39    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  40};
  41use std::collections::{VecDeque, hash_map};
  42use std::env;
  43use text::{AnchorRangeExt, Edit};
  44use workspace::Workspace;
  45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  46
  47use std::mem;
  48use std::ops::Range;
  49use std::path::Path;
  50use std::rc::Rc;
  51use std::str::FromStr as _;
  52use std::sync::Arc;
  53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  54use thiserror::Error;
  55use util::{RangeExt as _, ResultExt as _};
  56
  57pub mod cursor_excerpt;
  58pub mod example_spec;
  59pub mod fim;
  60mod license_detection;
  61pub mod mercury;
  62pub mod ollama;
  63mod onboarding_modal;
  64pub mod open_ai_response;
  65mod prediction;
  66pub mod sweep_ai;
  67
  68pub mod udiff;
  69
  70mod capture_example;
  71pub mod open_ai_compatible;
  72mod zed_edit_prediction_delegate;
  73pub mod zeta;
  74
  75#[cfg(test)]
  76mod edit_prediction_tests;
  77
  78use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
  79use crate::example_spec::ExampleSpec;
  80use crate::license_detection::LicenseDetectionWatcher;
  81use crate::mercury::Mercury;
  82use crate::onboarding_modal::ZedPredictModal;
  83pub use crate::prediction::EditPrediction;
  84pub use crate::prediction::EditPredictionId;
  85use crate::prediction::EditPredictionResult;
  86pub use crate::sweep_ai::SweepAi;
  87pub use capture_example::capture_example;
  88pub use language_model::ApiKeyState;
  89pub use telemetry_events::EditPredictionRating;
  90pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  91
// Actions registered under the `edit_prediction` namespace via gpui's `actions!` macro.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
 101
 102/// Maximum number of events to track.
 103const EVENT_COUNT_MAX: usize = 10;
 104const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 105const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
 106const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
 107const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 108const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 109const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
 110const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
 111const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
 112
/// Feature flag named `edit_prediction_jumps`.
/// NOTE(review): presumably gates jump-style predictions (`BufferEditPrediction::Jump`) — confirm.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
 118
/// Newtype that stores the app-wide [`EditPredictionStore`] entity as a gpui [`Global`].
/// It is read back by `EditPredictionStore::try_global` and `EditPredictionStore::global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 123
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// `EditPredictionStore::new` initializes it from the environment via
/// `zeta2_raw_config_from_env`:
/// - `ZED_ZETA_FORMAT` is required and parsed into `format`;
/// - `ZED_ZETA_MODEL` and `ZED_ZETA_ENVIRONMENT` are optional.
///
/// NOTE(review): an earlier comment said "the version" doubles as the Baseten environment
/// name (lowercased). There is now a separate `environment` field; confirm how it is
/// mapped to a deployment.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Model identifier (`ZED_ZETA_MODEL` when loaded from the environment).
    pub model_id: Option<String>,
    /// Environment name (`ZED_ZETA_ENVIRONMENT` when loaded from the environment).
    pub environment: Option<String>,
    /// Prompt format used when constructing the prompt client-side.
    pub format: ZetaFormat,
}
 133
/// Central edit-prediction store. It holds:
/// - per-project state and the selected prediction model;
/// - the experiment selection;
/// - the channels feeding the background workers that report rejected and settled
///   predictions.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Waits until a current user is present, then refreshes the available experiments.
    _fetch_experiments_task: Task<()>,
    /// Per-project state (presumably keyed by the project's entity id — confirm).
    projects: HashMap<EntityId, ProjectState>,
    update_required: bool,
    /// Backend used for predictions; starts as `EditPredictionModel::Zeta`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration, initialized from environment variables.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    /// Experiment chosen via `set_preferred_experiment`, if any.
    preferred_experiment: Option<String>,
    available_experiments: Vec<String>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Feeds `handle_rejected_predictions`, which runs on the background executor.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Feeds `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
 156
/// A rejection sent over `EditPredictionStore::reject_predictions_tx`.
/// It is consumed by the background `handle_rejected_predictions` task.
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    /// Organization associated with the rejection, if any.
    organization_id: Option<OrganizationId>,
}
 161
/// Backend used to produce edit predictions.
/// `EditPredictionStore::new` starts with `Zeta`; `set_edit_prediction_model` changes it.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta,
    /// Presumably fill-in-the-middle prompting with the given format (see the `fim` module).
    Fim { format: EditPredictionPromptFormat },
    /// Sweep backend (see [`SweepAi`]).
    Sweep,
    /// Mercury backend (see the `mercury` module).
    Mercury,
}
 169
/// Inputs gathered for a single prediction request.
/// NOTE(review): the consumers are outside this section; field notes are partly inferred
/// from names — confirm.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` the request is built from.
    snapshot: BufferSnapshot,
    /// Position in `buffer` to predict at (presumably the cursor).
    position: Anchor,
    /// Recent edit events.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional channel for [`DebugEvent`]s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    pub user_actions: Vec<UserActionRecord>,
}
 186
/// Debug notifications about context retrieval and prediction runs.
/// They are delivered through the optional `debug_tx` senders held by `ProjectState`
/// and `EditPredictionModelInput`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Payload of [`DebugEvent::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Payload of [`DebugEvent::ContextRetrievalFinished`].
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Named values describing the retrieval result.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Payload of [`DebugEvent::EditPredictionStarted`].
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Prompt text, when one is available.
    pub prompt: Option<String>,
}

/// Payload of [`DebugEvent::EditPredictionFinished`].
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when one is available.
    pub model_output: Option<String>,
}
 222
/// Capacity of `ProjectState::user_actions`.
/// `record_user_action` evicts the oldest entry once this many are stored.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One recorded user action within a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action occurred in.
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    /// Presumably milliseconds since the Unix epoch (per the field name) — confirm.
    pub timestamp_epoch_ms: u64,
}

/// Kind of a recorded user action.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 242
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// Prompt-facing event (a `BufferChange` when produced by `LastEvent::finalize`).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot the event's diff was computed from.
    pub old_snapshot: TextBufferSnapshot,
    /// Buffer version after the event.
    /// `can_merge` requires the next event's `old_snapshot` to have exactly this version.
    pub new_snapshot_version: clock::Global,
    /// Anchors covering the edited region, created in the post-event snapshot.
    pub total_edit_range: Range<Anchor>,
}
 251
 252impl StoredEvent {
 253    fn can_merge(
 254        &self,
 255        next_old_event: &StoredEvent,
 256        latest_snapshot: &TextBufferSnapshot,
 257        latest_edit_range: &Range<Anchor>,
 258    ) -> bool {
 259        // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
 260        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 261            return false;
 262        }
 263        if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
 264            return false;
 265        }
 266        if self.new_snapshot_version != next_old_event.old_snapshot.version {
 267            return false;
 268        }
 269        if !latest_snapshot
 270            .version
 271            .observed_all(&next_old_event.new_snapshot_version)
 272        {
 273            return false;
 274        }
 275
 276        let a_is_predicted = matches!(
 277            self.event.as_ref(),
 278            zeta_prompt::Event::BufferChange {
 279                predicted: true,
 280                ..
 281            }
 282        );
 283        let b_is_predicted = matches!(
 284            next_old_event.event.as_ref(),
 285            zeta_prompt::Event::BufferChange {
 286                predicted: true,
 287                ..
 288            }
 289        );
 290
 291        // If events come from the same source (both predicted or both manual) then
 292        // we would have coalesced them already.
 293        if a_is_predicted == b_is_predicted {
 294            return false;
 295        }
 296
 297        let left_range = self.total_edit_range.to_point(latest_snapshot);
 298        let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
 299        let latest_range = latest_edit_range.to_point(latest_snapshot);
 300
 301        // Events near to the latest edit are not merged if their sources differ.
 302        if lines_between_ranges(&left_range, &latest_range)
 303            .min(lines_between_ranges(&right_range, &latest_range))
 304            <= CHANGE_GROUPING_LINE_SPAN
 305        {
 306            return false;
 307        }
 308
 309        // Events that are distant from each other are not merged.
 310        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 311            return false;
 312        }
 313
 314        true
 315    }
 316}
 317
 318fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 319    if left.start > right.end {
 320        return left.start.row - right.end.row;
 321    }
 322    if right.start > left.end {
 323        return right.start.row - left.end.row;
 324    }
 325    0
 326}
 327
/// Per-project edit-prediction state.
struct ProjectState {
    /// Finalized edit events; `events()` returns these followed by the finalized `last_event`.
    events: VecDeque<StoredEvent>,
    /// Event still being accumulated.
    /// `events()` finalizes it on demand, split at the last editing pause.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Registered buffers keyed by buffer entity id (see `active_buffer`).
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests; the `ArrayVec` holds at most two.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions passed to `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree.
    /// `LastEvent::finalize` consults them to set `in_open_source_repo`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Most recent user actions, bounded by `USER_ACTION_HISTORY_SIZE`.
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
 346
 347impl ProjectState {
 348    fn record_user_action(&mut self, action: UserActionRecord) {
 349        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 350            self.user_actions.pop_front();
 351        }
 352        self.user_actions.push_back(action);
 353    }
 354
 355    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 356        self.events
 357            .iter()
 358            .cloned()
 359            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 360                let (one, two) = event.split_by_pause();
 361                let one = one.finalize(&self.license_detection_watchers, cx);
 362                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 363                one.into_iter().chain(two)
 364            }))
 365            .collect()
 366    }
 367
 368    fn cancel_pending_prediction(
 369        &mut self,
 370        pending_prediction: PendingPrediction,
 371        cx: &mut Context<EditPredictionStore>,
 372    ) {
 373        self.cancelled_predictions.insert(pending_prediction.id);
 374
 375        if pending_prediction.drop_on_cancel {
 376            drop(pending_prediction.task);
 377        } else {
 378            cx.spawn(async move |this, cx| {
 379                let Some(prediction_id) = pending_prediction.task.await else {
 380                    return;
 381                };
 382
 383                this.update(cx, |this, cx| {
 384                    this.reject_prediction(
 385                        prediction_id,
 386                        EditPredictionRejectReason::Canceled,
 387                        false,
 388                        None,
 389                        cx,
 390                    );
 391                })
 392                .ok();
 393            })
 394            .detach()
 395        }
 396    }
 397
 398    fn active_buffer(
 399        &self,
 400        project: &Entity<Project>,
 401        cx: &App,
 402    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 403        let project = project.read(cx);
 404        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 405        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 406        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 407        Some((active_buffer, registered_buffer.last_position))
 408    }
 409}
 410
/// The prediction currently held in `ProjectState::current_prediction`.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been shown (presumably set by the UI layer — confirm).
    pub was_shown: bool,
    /// How the prediction was displayed, if it was.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
 418
 419impl CurrentEditPrediction {
 420    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 421        let Some(new_edits) = self
 422            .prediction
 423            .interpolate(&self.prediction.buffer.read(cx))
 424        else {
 425            return false;
 426        };
 427
 428        if self.prediction.buffer != old_prediction.prediction.buffer {
 429            return true;
 430        }
 431
 432        let Some(old_edits) = old_prediction
 433            .prediction
 434            .interpolate(&old_prediction.prediction.buffer.read(cx))
 435        else {
 436            return true;
 437        };
 438
 439        let requested_by_buffer_id = self.requested_by.buffer_id();
 440
 441        // This reduces the occurrence of UI thrash from replacing edits
 442        //
 443        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 444        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 445            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 446            && old_edits.len() == 1
 447            && new_edits.len() == 1
 448        {
 449            let (old_range, old_text) = &old_edits[0];
 450            let (new_range, new_text) = &new_edits[0];
 451            new_range == old_range && new_text.starts_with(old_text.as_ref())
 452        } else {
 453            true
 454        }
 455    }
 456}
 457
 458#[derive(Debug, Clone)]
 459enum PredictionRequestedBy {
 460    DiagnosticsUpdate,
 461    Buffer(EntityId),
 462}
 463
 464impl PredictionRequestedBy {
 465    pub fn buffer_id(&self) -> Option<EntityId> {
 466        match self {
 467            PredictionRequestedBy::DiagnosticsUpdate => None,
 468            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 469        }
 470    }
 471}
 472
// NOTE(review): the consumer is not visible in this section. This is presumably the number
// of lines around the cursor searched for diagnostics — confirm.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostic search.
/// NOTE(review): presumably nearby (`Local`) vs. project-wide (`Global`) — confirm at call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
 480
/// An in-flight prediction request stored in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Inserted into `ProjectState::cancelled_predictions` when this request is cancelled.
    id: usize,
    /// Resolves to the produced prediction's id, or `None` if there was none.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 489
 490/// A prediction from the perspective of a buffer.
 491#[derive(Debug)]
 492enum BufferEditPrediction<'a> {
 493    Local { prediction: &'a EditPrediction },
 494    Jump { prediction: &'a EditPrediction },
 495}
 496
 497#[cfg(test)]
 498impl std::ops::Deref for BufferEditPrediction<'_> {
 499    type Target = EditPrediction;
 500
 501    fn deref(&self) -> &Self::Target {
 502        match self {
 503            BufferEditPrediction::Local { prediction } => prediction,
 504            BufferEditPrediction::Jump { prediction } => prediction,
 505        }
 506    }
 507}
 508
/// A prediction tracked per buffer in `RegisteredBuffer::pending_predictions`.
/// NOTE(review): the enqueue/settle logic is outside this section. The notes below are
/// inferred from names and the `EDIT_PREDICTION_SETTLED_*` constants — confirm.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Presumably the region the prediction was allowed to edit.
    editable_anchor_range: Range<Anchor>,
    example: Option<ExampleSpec>,
    /// Presumably compared against `EDIT_PREDICTION_SETTLED_TTL`.
    enqueued_at: Instant,
    /// Presumably compared against `EDIT_PREDICTION_SETTLED_QUIESCENCE`.
    last_edit_at: Instant,
}
 517
/// Tracking state for a buffer registered in `ProjectState::registered_buffers`.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last recorded position in this buffer; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
 525
/// The buffer-change event still being accumulated.
/// `finalize` converts it into a [`StoredEvent`].
#[derive(Clone)]
struct LastEvent {
    /// Snapshot the event's diff is computed from.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot the event's diff is computed to.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    latest_edit_range: Range<Anchor>,
    /// Range covering this event's edits; `finalize` diffs around it.
    total_edit_range: Range<Anchor>,
    /// Edit range as of the last editing pause.
    /// `split_by_pause` uses it for the pre-pause half and falls back to `total_edit_range`.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    /// Whether the edits came from a prediction rather than manual typing.
    predicted: bool,
    /// Snapshot at the last editing pause; `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 539
 540impl LastEvent {
 541    pub fn finalize(
 542        &self,
 543        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 544        cx: &App,
 545    ) -> Option<StoredEvent> {
 546        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 547        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 548
 549        let in_open_source_repo =
 550            [self.new_file.as_ref(), self.old_file.as_ref()]
 551                .iter()
 552                .all(|file| {
 553                    file.is_some_and(|file| {
 554                        license_detection_watchers
 555                            .get(&file.worktree_id(cx))
 556                            .is_some_and(|watcher| watcher.is_project_open_source())
 557                    })
 558                });
 559
 560        let (diff, edit_range) = compute_diff_between_snapshots_in_range(
 561            &self.old_snapshot,
 562            &self.new_snapshot,
 563            &self.total_edit_range,
 564        )?;
 565
 566        if path == old_path && diff.is_empty() {
 567            None
 568        } else {
 569            Some(StoredEvent {
 570                event: Arc::new(zeta_prompt::Event::BufferChange {
 571                    old_path,
 572                    path,
 573                    diff,
 574                    in_open_source_repo,
 575                    predicted: self.predicted,
 576                }),
 577                old_snapshot: self.old_snapshot.clone(),
 578                new_snapshot_version: self.new_snapshot.version.clone(),
 579                total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
 580                    ..self.new_snapshot.anchor_before(edit_range.end),
 581            })
 582        }
 583    }
 584
 585    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 586        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 587            return (self.clone(), None);
 588        };
 589
 590        let total_edit_range_before_pause = self
 591            .total_edit_range_at_last_pause_boundary
 592            .clone()
 593            .unwrap_or_else(|| self.total_edit_range.clone());
 594
 595        let Some(total_edit_range_after_pause) =
 596            compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
 597        else {
 598            return (self.clone(), None);
 599        };
 600
 601        let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
 602        let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
 603
 604        let before = LastEvent {
 605            old_snapshot: self.old_snapshot.clone(),
 606            new_snapshot: boundary_snapshot.clone(),
 607            old_file: self.old_file.clone(),
 608            new_file: self.new_file.clone(),
 609            latest_edit_range: latest_edit_range_before_pause,
 610            total_edit_range: total_edit_range_before_pause,
 611            total_edit_range_at_last_pause_boundary: None,
 612            predicted: self.predicted,
 613            snapshot_after_last_editing_pause: None,
 614            last_edit_time: self.last_edit_time,
 615        };
 616
 617        let after = LastEvent {
 618            old_snapshot: boundary_snapshot.clone(),
 619            new_snapshot: self.new_snapshot.clone(),
 620            old_file: self.old_file.clone(),
 621            new_file: self.new_file.clone(),
 622            latest_edit_range: latest_edit_range_after_pause,
 623            total_edit_range: total_edit_range_after_pause,
 624            total_edit_range_at_last_pause_boundary: None,
 625            predicted: self.predicted,
 626            snapshot_after_last_editing_pause: None,
 627            last_edit_time: self.last_edit_time,
 628        };
 629
 630        (before, Some(after))
 631    }
 632}
 633
 634fn compute_total_edit_range_between_snapshots(
 635    old_snapshot: &TextBufferSnapshot,
 636    new_snapshot: &TextBufferSnapshot,
 637) -> Option<Range<Anchor>> {
 638    let edits: Vec<Edit<usize>> = new_snapshot
 639        .edits_since::<usize>(&old_snapshot.version)
 640        .collect();
 641
 642    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 643    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 644    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 645
 646    Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
 647}
 648
/// Maps `total_edit_range` (anchors resolved in `new_snapshot`) back to the corresponding
/// point range in `old_snapshot`.
///
/// The mapping uses the edits made between the two snapshot versions. Each endpoint is
/// translated independently while walking the edits in the order `edits_since` yields
/// them (assumed ascending — TODO confirm):
/// - An endpoint before an edit's new range is offset by `-delta`, where `delta` is the
///   net length change of all earlier edits.
/// - An endpoint inside an edit's new range (both ends inclusive) snaps to that edit's
///   old start (for the range start) or old end (for the range end).
/// - An endpoint past every edit is offset by the negated total delta, saturating at 0.
///
/// Returns `None` if shifting an endpoint that precedes an edit would underflow.
fn compute_old_range_for_new_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<Range<Point>> {
    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);

    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();
    let mut old_start_offset = None;
    let mut old_end_offset = None;
    // Net (new length - old length) of the edits processed so far.
    let mut delta: isize = 0;

    for edit in &edits {
        // The first edit whose new range ends at or after the endpoint decides its mapping.
        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
            old_start_offset = Some(if new_start_offset < edit.new.start {
                // Endpoint lies in unchanged text before this edit: undo earlier edits' shift.
                new_start_offset.checked_add_signed(-delta)?
            } else {
                // Endpoint lies inside this edit: snap to the edit's old start.
                edit.old.start
            });
        }

        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
            old_end_offset = Some(if new_end_offset < edit.new.start {
                new_end_offset.checked_add_signed(-delta)?
            } else {
                // Endpoint lies inside this edit: snap to the edit's old end.
                edit.old.end
            });
        }

        delta += edit.new.len() as isize - edit.old.len() as isize;
    }

    // Endpoints past every edit: undo the total length change.
    let old_start_offset =
        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
    let old_end_offset =
        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));

    Some(
        old_snapshot.offset_to_point(old_start_offset)
            ..old_snapshot.offset_to_point(old_end_offset),
    )
}
 694
 695fn compute_diff_between_snapshots_in_range(
 696    old_snapshot: &TextBufferSnapshot,
 697    new_snapshot: &TextBufferSnapshot,
 698    total_edit_range: &Range<Anchor>,
 699) -> Option<(String, Range<Point>)> {
 700    let new_start_point = total_edit_range.start.to_point(new_snapshot);
 701    let new_end_point = total_edit_range.end.to_point(new_snapshot);
 702    let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
 703    let old_start_point = old_range.start;
 704    let old_end_point = old_range.end;
 705
 706    const CONTEXT_LINES: u32 = 3;
 707
 708    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 709    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 710    let old_context_end_row =
 711        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 712    let new_context_end_row =
 713        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 714
 715    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 716    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 717    let old_end_line_offset = old_snapshot
 718        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 719    let new_end_line_offset = new_snapshot
 720        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 721    let old_edit_range = old_start_line_offset..old_end_line_offset;
 722    let new_edit_range = new_start_line_offset..new_end_line_offset;
 723
 724    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 725    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 726
 727    let diff = language::unified_diff_with_offsets(
 728        &old_region_text,
 729        &new_region_text,
 730        old_context_start_row,
 731        new_context_start_row,
 732    );
 733
 734    Some((diff, new_start_point..new_end_point))
 735}
 736
 737fn buffer_path_with_id_fallback(
 738    file: Option<&Arc<dyn File>>,
 739    snapshot: &TextBufferSnapshot,
 740    cx: &App,
 741) -> Arc<Path> {
 742    if let Some(file) = file {
 743        file.full_path(cx).into()
 744    } else {
 745        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 746    }
 747}
 748
 749impl EditPredictionStore {
 750    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 751        cx.try_global::<EditPredictionStoreGlobal>()
 752            .map(|global| global.0.clone())
 753    }
 754
 755    pub fn global(
 756        client: &Arc<Client>,
 757        user_store: &Entity<UserStore>,
 758        cx: &mut App,
 759    ) -> Entity<Self> {
 760        cx.try_global::<EditPredictionStoreGlobal>()
 761            .map(|global| global.0.clone())
 762            .unwrap_or_else(|| {
 763                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 764                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 765                ep_store
 766            })
 767    }
 768
 769    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 770        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 771        let data_collection_choice = Self::load_data_collection_choice();
 772
 773        let llm_token = LlmApiToken::default();
 774
 775        let (reject_tx, reject_rx) = mpsc::unbounded();
 776        cx.background_spawn({
 777            let client = client.clone();
 778            let llm_token = llm_token.clone();
 779            let app_version = AppVersion::global(cx);
 780            let background_executor = cx.background_executor().clone();
 781            async move {
 782                Self::handle_rejected_predictions(
 783                    reject_rx,
 784                    client,
 785                    llm_token,
 786                    app_version,
 787                    background_executor,
 788                )
 789                .await
 790            }
 791        })
 792        .detach();
 793
 794        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
 795        cx.spawn(async move |this, cx| {
 796            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
 797        })
 798        .detach();
 799
 800        let mut current_user = user_store.read(cx).watch_current_user();
 801        let fetch_experiments_task = cx.spawn(async move |this, cx| {
 802            while current_user.borrow().is_none() {
 803                current_user.next().await;
 804            }
 805            this.update(cx, |this, cx| {
 806                this.refresh_available_experiments(cx);
 807            })
 808            .log_err();
 809        });
 810
 811        let this = Self {
 812            projects: HashMap::default(),
 813            client,
 814            user_store,
 815            llm_token,
 816            _fetch_experiments_task: fetch_experiments_task,
 817            _llm_token_subscription: cx.subscribe(
 818                &refresh_llm_token_listener,
 819                |this, _listener, _event, cx| {
 820                    let client = this.client.clone();
 821                    let llm_token = this.llm_token.clone();
 822                    let organization_id = this
 823                        .user_store
 824                        .read(cx)
 825                        .current_organization()
 826                        .map(|organization| organization.id.clone());
 827                    cx.spawn(async move |_this, _cx| {
 828                        llm_token.refresh(&client, organization_id).await?;
 829                        anyhow::Ok(())
 830                    })
 831                    .detach_and_log_err(cx);
 832                },
 833            ),
 834            update_required: false,
 835            edit_prediction_model: EditPredictionModel::Zeta,
 836            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 837            preferred_experiment: None,
 838            available_experiments: Vec::new(),
 839            sweep_ai: SweepAi::new(cx),
 840            mercury: Mercury::new(cx),
 841
 842            data_collection_choice,
 843            reject_predictions_tx: reject_tx,
 844            settled_predictions_tx,
 845            rated_predictions: Default::default(),
 846            shown_predictions: Default::default(),
 847            #[cfg(test)]
 848            settled_event_callback: None,
 849        };
 850
 851        this
 852    }
 853
 854    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 855        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 856        let format = ZetaFormat::parse(&version_str).ok()?;
 857        let model_id = env::var("ZED_ZETA_MODEL").ok();
 858        let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
 859        Some(Zeta2RawConfig {
 860            model_id,
 861            environment,
 862            format,
 863        })
 864    }
 865
    /// Sets the model variant that `icons`, `usage` and `accept_current_prediction` dispatch on.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
 869
    /// Overrides the raw Zeta2 configuration, replacing any value read from the environment
    /// by `zeta2_raw_config_from_env`.
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }
 873
    /// Returns the raw Zeta2 configuration, if one was read from the environment or set via
    /// `set_zeta2_raw_config`.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }
 877
 878    pub fn preferred_experiment(&self) -> Option<&str> {
 879        self.preferred_experiment.as_deref()
 880    }
 881
    /// Sets (or clears, with `None`) the experiment that `active_experiment` reports in
    /// preference to the one inferred from shown predictions.
    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
        self.preferred_experiment = experiment;
    }
 885
 886    pub fn available_experiments(&self) -> &[String] {
 887        &self.available_experiments
 888    }
 889
 890    pub fn active_experiment(&self) -> Option<&str> {
 891        self.preferred_experiment.as_deref().or_else(|| {
 892            self.shown_predictions
 893                .iter()
 894                .find_map(|p| p.model_version.as_ref())
 895                .and_then(|model_version| model_version.strip_prefix("zeta2:"))
 896        })
 897    }
 898
 899    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
 900        let client = self.client.clone();
 901        let llm_token = self.llm_token.clone();
 902        let app_version = AppVersion::global(cx);
 903        let organization_id = self
 904            .user_store
 905            .read(cx)
 906            .current_organization()
 907            .map(|organization| organization.id.clone());
 908
 909        cx.spawn(async move |this, cx| {
 910            let experiments = cx
 911                .background_spawn(async move {
 912                    let http_client = client.http_client();
 913                    let token = llm_token.acquire(&client, organization_id).await?;
 914                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
 915                    let request = http_client::Request::builder()
 916                        .method(Method::GET)
 917                        .uri(url.as_ref())
 918                        .header("Authorization", format!("Bearer {}", token))
 919                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 920                        .body(Default::default())?;
 921                    let mut response = http_client.send(request).await?;
 922                    if response.status().is_success() {
 923                        let mut body = Vec::new();
 924                        response.body_mut().read_to_end(&mut body).await?;
 925                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
 926                        Ok(experiments)
 927                    } else {
 928                        let mut body = String::new();
 929                        response.body_mut().read_to_string(&mut body).await?;
 930                        anyhow::bail!(
 931                            "Failed to fetch experiments: {:?}\nBody: {}",
 932                            response.status(),
 933                            body
 934                        );
 935                    }
 936                })
 937                .await?;
 938            this.update(cx, |this, cx| {
 939                this.available_experiments = experiments;
 940                cx.notify();
 941            })?;
 942            anyhow::Ok(())
 943        })
 944        .detach_and_log_err(cx);
 945    }
 946
 947    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 948        use ui::IconName;
 949        match self.edit_prediction_model {
 950            EditPredictionModel::Sweep => {
 951                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 952                    .with_disabled(IconName::SweepAiDisabled)
 953                    .with_up(IconName::SweepAiUp)
 954                    .with_down(IconName::SweepAiDown)
 955                    .with_error(IconName::SweepAiError)
 956            }
 957            EditPredictionModel::Mercury => {
 958                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 959            }
 960            EditPredictionModel::Zeta => {
 961                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 962                    .with_disabled(IconName::ZedPredictDisabled)
 963                    .with_up(IconName::ZedPredictUp)
 964                    .with_down(IconName::ZedPredictDown)
 965                    .with_error(IconName::ZedPredictError)
 966            }
 967            EditPredictionModel::Fim { .. } => {
 968                let settings = &all_language_settings(None, cx).edit_predictions;
 969                match settings.provider {
 970                    EditPredictionProvider::Ollama => {
 971                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 972                    }
 973                    _ => {
 974                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 975                    }
 976                }
 977            }
 978        }
 979    }
 980
 981    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 982        self.sweep_ai.api_token.read(cx).has_key()
 983    }
 984
 985    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 986        self.mercury.api_token.read(cx).has_key()
 987    }
 988
 989    pub fn clear_history(&mut self) {
 990        for project_state in self.projects.values_mut() {
 991            project_state.events.clear();
 992            project_state.last_event.take();
 993        }
 994    }
 995
 996    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 997        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 998            project_state.events.clear();
 999            project_state.last_event.take();
1000        }
1001    }
1002
1003    pub fn edit_history_for_project(
1004        &self,
1005        project: &Entity<Project>,
1006        cx: &App,
1007    ) -> Vec<StoredEvent> {
1008        self.projects
1009            .get(&project.entity_id())
1010            .map(|project_state| project_state.events(cx))
1011            .unwrap_or_default()
1012    }
1013
1014    pub fn context_for_project<'a>(
1015        &'a self,
1016        project: &Entity<Project>,
1017        cx: &'a mut App,
1018    ) -> Vec<RelatedFile> {
1019        self.projects
1020            .get(&project.entity_id())
1021            .map(|project_state| {
1022                project_state.context.update(cx, |context, cx| {
1023                    context
1024                        .related_files_with_buffers(cx)
1025                        .map(|(mut related_file, buffer)| {
1026                            related_file.in_open_source_repo = buffer
1027                                .read(cx)
1028                                .file()
1029                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
1030                            related_file
1031                        })
1032                        .collect()
1033                })
1034            })
1035            .unwrap_or_default()
1036    }
1037
1038    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1039        self.projects
1040            .get(&project.entity_id())
1041            .and_then(|project| project.copilot.clone())
1042    }
1043
1044    pub fn start_copilot_for_project(
1045        &mut self,
1046        project: &Entity<Project>,
1047        cx: &mut Context<Self>,
1048    ) -> Option<Entity<Copilot>> {
1049        if DisableAiSettings::get(None, cx).disable_ai {
1050            return None;
1051        }
1052        let state = self.get_or_init_project(project, cx);
1053
1054        if state.copilot.is_some() {
1055            return state.copilot.clone();
1056        }
1057        let _project = project.clone();
1058        let project = project.read(cx);
1059
1060        let node = project.node_runtime().cloned();
1061        if let Some(node) = node {
1062            let next_id = project.languages().next_language_server_id();
1063            let fs = project.fs().clone();
1064
1065            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1066            state.copilot = Some(copilot.clone());
1067            Some(copilot)
1068        } else {
1069            None
1070        }
1071    }
1072
1073    pub fn context_for_project_with_buffers<'a>(
1074        &'a self,
1075        project: &Entity<Project>,
1076        cx: &'a mut App,
1077    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1078        self.projects
1079            .get(&project.entity_id())
1080            .map(|project| {
1081                project.context.update(cx, |context, cx| {
1082                    context.related_files_with_buffers(cx).collect()
1083                })
1084            })
1085            .unwrap_or_default()
1086    }
1087
1088    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1089        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1090            self.user_store.read(cx).edit_prediction_usage()
1091        } else {
1092            None
1093        }
1094    }
1095
1096    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1097        self.get_or_init_project(project, cx);
1098    }
1099
1100    pub fn register_buffer(
1101        &mut self,
1102        buffer: &Entity<Buffer>,
1103        project: &Entity<Project>,
1104        cx: &mut Context<Self>,
1105    ) {
1106        let project_state = self.get_or_init_project(project, cx);
1107        Self::register_buffer_impl(project_state, buffer, project, cx);
1108    }
1109
    /// Returns the state tracked for `project`, creating it on first use.
    ///
    /// On creation this builds the project's `RelatedExcerptStore` and subscribes to its
    /// events, forwarding them to `handle_excerpt_store_event`. It also subscribes to the
    /// project's own events, routed to `handle_project_event`, and to its release, which removes
    /// this entry from `projects` and calls `cx.notify()`.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            // Everything below (store creation, subscriptions) only runs for a new entry.
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop this project's state once the project entity is released.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1150
1151    pub fn remove_project(&mut self, project: &Entity<Project>) {
1152        self.projects.remove(&project.entity_id());
1153    }
1154
1155    fn handle_excerpt_store_event(
1156        &mut self,
1157        project_entity_id: EntityId,
1158        event: &RelatedExcerptStoreEvent,
1159    ) {
1160        if let Some(project_state) = self.projects.get(&project_entity_id) {
1161            if let Some(debug_tx) = project_state.debug_tx.clone() {
1162                match event {
1163                    RelatedExcerptStoreEvent::StartedRefresh => {
1164                        debug_tx
1165                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
1166                                ContextRetrievalStartedDebugEvent {
1167                                    project_entity_id: project_entity_id,
1168                                    timestamp: Instant::now(),
1169                                    search_prompt: String::new(),
1170                                },
1171                            ))
1172                            .ok();
1173                    }
1174                    RelatedExcerptStoreEvent::FinishedRefresh {
1175                        cache_hit_count,
1176                        cache_miss_count,
1177                        mean_definition_latency,
1178                        max_definition_latency,
1179                    } => {
1180                        debug_tx
1181                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
1182                                ContextRetrievalFinishedDebugEvent {
1183                                    project_entity_id: project_entity_id,
1184                                    timestamp: Instant::now(),
1185                                    metadata: vec![
1186                                        (
1187                                            "Cache Hits",
1188                                            format!(
1189                                                "{}/{}",
1190                                                cache_hit_count,
1191                                                cache_hit_count + cache_miss_count
1192                                            )
1193                                            .into(),
1194                                        ),
1195                                        (
1196                                            "Max LSP Time",
1197                                            format!("{} ms", max_definition_latency.as_millis())
1198                                                .into(),
1199                                        ),
1200                                        (
1201                                            "Mean LSP Time",
1202                                            format!("{} ms", mean_definition_latency.as_millis())
1203                                                .into(),
1204                                        ),
1205                                    ],
1206                                },
1207                            ))
1208                            .ok();
1209                    }
1210                }
1211            }
1212        }
1213    }
1214
1215    pub fn debug_info(
1216        &mut self,
1217        project: &Entity<Project>,
1218        cx: &mut Context<Self>,
1219    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1220        let project_state = self.get_or_init_project(project, cx);
1221        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1222        project_state.debug_tx = Some(debug_watch_tx);
1223        debug_watch_rx
1224    }
1225
1226    fn handle_project_event(
1227        &mut self,
1228        project: Entity<Project>,
1229        event: &project::Event,
1230        cx: &mut Context<Self>,
1231    ) {
1232        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1233            return;
1234        }
1235        // TODO [zeta2] init with recent paths
1236        match event {
1237            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1238                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1239                    return;
1240                };
1241                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1242                if let Some(path) = path {
1243                    if let Some(ix) = project_state
1244                        .recent_paths
1245                        .iter()
1246                        .position(|probe| probe == &path)
1247                    {
1248                        project_state.recent_paths.remove(ix);
1249                    }
1250                    project_state.recent_paths.push_front(path);
1251                }
1252            }
1253            project::Event::DiagnosticsUpdated { .. } => {
1254                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1255                    self.refresh_prediction_from_diagnostics(
1256                        project,
1257                        DiagnosticSearchScope::Global,
1258                        cx,
1259                    );
1260                }
1261            }
1262            _ => (),
1263        }
1264    }
1265
    /// Ensures `buffer` is tracked in `project_state` and returns its `RegisteredBuffer`.
    ///
    /// Suppose the buffer has a file whose worktree is found in `project`. Then a
    /// `LicenseDetectionWatcher` is created for that worktree, at most once per worktree. It is
    /// removed again when the worktree entity is released.
    ///
    /// A newly registered buffer starts from its current text snapshot and file and gets two
    /// subscriptions:
    /// - `Edited` events are passed to `report_changes_for_buffer` as non-predicted edits,
    ///   carrying the event's `is_local` flag.
    /// - Releasing the buffer removes it from `registered_buffers`.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher once its worktree entity is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        // Captures a weak project handle; edits are ignored once it can no
                        // longer be upgraded.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1334
    /// Records the edits made to `buffer` since its last recorded snapshot.
    ///
    /// The buffer (and project) are registered if needed. The method then:
    /// - diffs the buffer's current snapshot against the stored one;
    /// - bumps `last_edit_at` on pending predictions whose editable range overlaps the edit;
    /// - for local edits, logs a `UserActionRecord`;
    /// - updates the edit history: either coalesces the edit into `last_event`, or finalizes the
    ///   previous `last_event` into `events` and starts a new one.
    ///
    /// - `is_predicted`: the edit came from an accepted prediction (`accept_current_prediction`
    ///   passes `true`). A change of this flag prevents coalescing with the previous event.
    /// - `is_local`: the edit was made locally. Non-local edits enter the history only when
    ///   `collaborator_edit_overlaps_locality_region` says they do.
    ///
    /// Returns early when the snapshot version is unchanged or no edits are found.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        // Swap in the new file/snapshot, keeping the previous ones to diff against.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        // Tally byte counts and build one anchor range running from the first edit's start to
        // the last edit's end.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if is_local {
            // Classify by byte totals: when the inserted (or deleted) byte count equals the
            // number of edits, treat it as a per-character action, otherwise as a selection.
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                // Falls back to 0 if the system clock is before the Unix epoch.
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        if !include_in_history {
            return;
        }

        let events = &mut project_state.events;

        if let Some(last_event) = project_state.last_event.as_mut() {
            // Coalesce only when this edit continues directly from the last event's snapshot
            // of the same buffer, has the same predicted/manual origin, and lands within
            // CHANGE_GROUPING_LINE_SPAN lines of the last edit.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // If at least LAST_CHANGE_GROUPING_TIME passed since the previous edit, remember
                // the state at this pause boundary before extending the event.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescing: finalize the previous event (if it yields one) into the history.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Keeps `events` below EVENT_COUNT_MAX, presumably leaving room for
                // `last_event` -- NOTE(review): confirm against `ProjectState::events`.
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1485
1486    fn prediction_at(
1487        &mut self,
1488        buffer: &Entity<Buffer>,
1489        position: Option<language::Anchor>,
1490        project: &Entity<Project>,
1491        cx: &App,
1492    ) -> Option<BufferEditPrediction<'_>> {
1493        let project_state = self.projects.get_mut(&project.entity_id())?;
1494        if let Some(position) = position
1495            && let Some(buffer) = project_state
1496                .registered_buffers
1497                .get_mut(&buffer.entity_id())
1498        {
1499            buffer.last_position = Some(position);
1500        }
1501
1502        let CurrentEditPrediction {
1503            requested_by,
1504            prediction,
1505            ..
1506        } = project_state.current_prediction.as_ref()?;
1507
1508        if prediction.targets_buffer(buffer.read(cx)) {
1509            Some(BufferEditPrediction::Local { prediction })
1510        } else {
1511            let show_jump = match requested_by {
1512                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1513                    requested_by_buffer_id == &buffer.entity_id()
1514                }
1515                PredictionRequestedBy::DiagnosticsUpdate => true,
1516            };
1517
1518            if show_jump {
1519                Some(BufferEditPrediction::Jump { prediction })
1520            } else {
1521                None
1522            }
1523        }
1524    }
1525
    /// Accepts the project's current prediction.
    ///
    /// The steps are:
    /// 1. Take `current_prediction` out of the project state; do nothing if there is none.
    /// 2. Record the prediction buffer's changes as a predicted, local edit.
    /// 3. Cancel every pending prediction for the project.
    /// 4. Notify the active backend's acceptance handler.
    ///
    /// For Zeta, acceptance is only reported when the provider is neither Ollama nor an
    /// OpenAI-compatible API. FIM reports nothing.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(current_prediction) = self
            .projects
            .get_mut(&project.entity_id())
            .and_then(|project_state| project_state.current_prediction.take())
        else {
            return;
        };

        self.report_changes_for_buffer(
            &current_prediction.prediction.buffer,
            project,
            true,
            true,
            cx,
        );

        // Can't hold a `&mut` project_state reference across the report_changes_for_buffer
        // call, so look the state up again.
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        match self.edit_prediction_model {
            EditPredictionModel::Sweep => {
                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
            }
            EditPredictionModel::Mercury => {
                mercury::edit_prediction_accepted(
                    current_prediction.prediction.id,
                    self.client.http_client(),
                    cx,
                );
            }
            EditPredictionModel::Zeta => {
                // Ollama and OpenAI-compatible providers are not reported as accepted.
                let is_cloud = !matches!(
                    all_language_settings(None, cx).edit_predictions.provider,
                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
                );
                if is_cloud {
                    zeta::edit_prediction_accepted(self, current_prediction, cx)
                }
            }
            EditPredictionModel::Fim { .. } => {}
        }
    }
1575
    /// Background loop that batches rejected predictions and reports them to the
    /// `/predict_edits/reject` endpoint.
    ///
    /// Rejections are accumulated while more keep arriving within
    /// `REJECT_REQUEST_DEBOUNCE` of each other (until the batch reaches half of
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`). Then at most
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the newest entries are sent
    /// in one request. Entries are removed only after a successful request, so
    /// failed ones are retried with a later flush. The loop returns once `rx` is
    /// closed; entries that were never sent successfully are discarded at that point.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        // `peekable` lets the debounce below wait for the next rejection without
        // consuming it.
        let mut rx = std::pin::pin!(rx.peekable());
        // Rejections not yet sent successfully, oldest first.
        let mut batched = Vec::new();

        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            // While the batch is still small, wait for whichever comes first:
            // another rejection or the debounce timer (`select_biased!` checks the
            // peek branch first). Another rejection restarts the loop, so it is
            // added to this batch and the debounce starts over. If the channel
            // closed (`peek` yields `None`), fall through and flush.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            // NOTE(review): `unwrap` panics if the URL cannot be built.
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            // Send only the newest entries that fit in one request.
            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            // NOTE(review): the whole batch is sent with the `organization_id` of
            // the most recently received payload, even if earlier rejections were
            // queued under a different organization.
            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Drop the sent entries only on success; on failure they stay in
            // `batched` and are retried with a later flush.
            // NOTE(review): nothing caps `batched`, so it grows without bound while
            // requests keep failing.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1639
    /// Background loop that turns pending "settled" predictions into telemetry.
    ///
    /// Each `Instant` received on `rx` signals that a prediction was enqueued.
    /// After the quiet period, the loop scans every registered buffer of every
    /// project:
    /// - entries whose `enqueued_at` is at least `EDIT_PREDICTION_SETTLED_TTL` old
    ///   are dropped without reporting;
    /// - entries whose `last_edit_at` is at least
    ///   `EDIT_PREDICTION_SETTLED_QUIESCENCE` old emit
    ///   `EDIT_PREDICTION_SETTLED_EVENT` with the current text of their editable
    ///   range, then are removed.
    ///
    /// The loop exits when `rx` is closed or the store entity has been released.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            // Either sleep until the scheduled scan, or block until something is
            // enqueued and schedule a scan for when it could first be settled.
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                // `Instant::duration_since` yields zero (on current Rust versions)
                // if `wake_time` has already passed.
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                // Coalesce notifications that are already queued; one scan covers
                // them all. Notifications that arrive while sleeping on the timer
                // stay in `rx` until the next time the loop waits here. The scan
                // walks all registered buffers, so their entries are still
                // considered.
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();

            // Earliest `last_edit_at` among the entries that survive this scan.
            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                // Too old: drop without reporting.
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                // Quiet long enough: report the region's text, taken
                                // from the registered buffer's stored `snapshot`, and
                                // remove the entry.
                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    // Test-only hook that observes each settled event.
                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                        example = pending_prediction.example.take(),
                                    );

                                    return false;
                                }

                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            // Schedule the next scan for when the least recently edited surviving
            // entry becomes quiet. With none left, wait on `rx` again.
            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1724
1725    pub(crate) fn enqueue_settled_prediction(
1726        &mut self,
1727        request_id: EditPredictionId,
1728        project: &Entity<Project>,
1729        edited_buffer: &Entity<Buffer>,
1730        edited_buffer_snapshot: &BufferSnapshot,
1731        editable_offset_range: Range<usize>,
1732        example: Option<ExampleSpec>,
1733        cx: &mut Context<Self>,
1734    ) {
1735        let this = &mut *self;
1736        let project_state = this.get_or_init_project(project, cx);
1737        if let Some(buffer) = project_state
1738            .registered_buffers
1739            .get_mut(&edited_buffer.entity_id())
1740        {
1741            let now = cx.background_executor().now();
1742            buffer.pending_predictions.push(PendingSettledPrediction {
1743                request_id: request_id,
1744                editable_anchor_range: edited_buffer_snapshot
1745                    .anchor_range_around(editable_offset_range),
1746                example,
1747                enqueued_at: now,
1748                last_edit_at: now,
1749            });
1750            this.settled_predictions_tx.unbounded_send(now).ok();
1751        }
1752    }
1753
1754    fn reject_current_prediction(
1755        &mut self,
1756        reason: EditPredictionRejectReason,
1757        project: &Entity<Project>,
1758        cx: &App,
1759    ) {
1760        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1761            project_state.pending_predictions.clear();
1762            if let Some(prediction) = project_state.current_prediction.take() {
1763                let model_version = prediction.prediction.model_version.clone();
1764                self.reject_prediction(
1765                    prediction.prediction.id,
1766                    reason,
1767                    prediction.was_shown,
1768                    model_version,
1769                    cx,
1770                );
1771            }
1772        };
1773    }
1774
    /// Records that `project`'s current prediction was displayed as `display_type`.
    ///
    /// This does four things:
    /// - updates the current prediction's `shown_with` and `was_shown`;
    /// - notifies Sweep (when it is the active model) if `display_type` differs
    ///   from the previously recorded display type;
    /// - on the first non-jump display, pushes the prediction onto the
    ///   `shown_predictions` history, which is capped at 50 entries.
    ///
    /// Does nothing if the project has no current prediction.
    fn did_show_current_prediction(
        &mut self,
        project: &Entity<Project>,
        display_type: edit_prediction_types::SuggestionDisplayType,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
            return;
        };

        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
        let previous_shown_with = current_prediction.shown_with;

        // A jump display is recorded only if nothing was recorded yet; any
        // non-jump display overwrites the recorded display type.
        if previous_shown_with.is_none() || !is_jump {
            current_prediction.shown_with = Some(display_type);
        }

        // Only a non-jump display counts as the prediction having been shown.
        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;

        if is_first_non_jump_show {
            current_prediction.was_shown = true;
        }

        // NOTE(review): this compares against the display type recorded *before*
        // the update above. A jump never overwrites an earlier non-jump
        // `shown_with`, so every later jump display compares unequal and
        // notifies Sweep again. Confirm that these repeated notifications are intended.
        let display_type_changed = previous_shown_with != Some(display_type);

        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
            sweep_ai::edit_prediction_shown(
                &self.sweep_ai,
                self.client.clone(),
                &current_prediction.prediction,
                display_type,
                cx,
            );
        }

        if is_first_non_jump_show {
            // Keep at most 50 entries, newest at the front. The evicted prediction's
            // id is also removed from `rated_predictions`.
            self.shown_predictions
                .push_front(current_prediction.prediction.clone());
            if self.shown_predictions.len() > 50 {
                // Cannot fail: the length was just checked to be above 50.
                let completion = self.shown_predictions.pop_back().unwrap();
                self.rated_predictions.remove(&completion.id);
            }
        }
    }
1823
1824    fn reject_prediction(
1825        &mut self,
1826        prediction_id: EditPredictionId,
1827        reason: EditPredictionRejectReason,
1828        was_shown: bool,
1829        model_version: Option<String>,
1830        cx: &App,
1831    ) {
1832        match self.edit_prediction_model {
1833            EditPredictionModel::Zeta => {
1834                let is_cloud = !matches!(
1835                    all_language_settings(None, cx).edit_predictions.provider,
1836                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1837                );
1838
1839                if is_cloud {
1840                    let organization_id = self
1841                        .user_store
1842                        .read(cx)
1843                        .current_organization()
1844                        .map(|organization| organization.id.clone());
1845
1846                    self.reject_predictions_tx
1847                        .unbounded_send(EditPredictionRejectionPayload {
1848                            rejection: EditPredictionRejection {
1849                                request_id: prediction_id.to_string(),
1850                                reason,
1851                                was_shown,
1852                                model_version,
1853                            },
1854                            organization_id,
1855                        })
1856                        .log_err();
1857                }
1858            }
1859            EditPredictionModel::Mercury => {
1860                mercury::edit_prediction_rejected(
1861                    prediction_id,
1862                    was_shown,
1863                    reason,
1864                    self.client.http_client(),
1865                    cx,
1866                );
1867            }
1868            EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1869        }
1870    }
1871
1872    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1873        self.projects
1874            .get(&project.entity_id())
1875            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1876    }
1877
1878    pub fn refresh_prediction_from_buffer(
1879        &mut self,
1880        project: Entity<Project>,
1881        buffer: Entity<Buffer>,
1882        position: language::Anchor,
1883        cx: &mut Context<Self>,
1884    ) {
1885        self.queue_prediction_refresh(
1886            project.clone(),
1887            PredictEditsRequestTrigger::Other,
1888            buffer.entity_id(),
1889            cx,
1890            move |this, cx| {
1891                let Some(request_task) = this
1892                    .update(cx, |this, cx| {
1893                        this.request_prediction(
1894                            &project,
1895                            &buffer,
1896                            position,
1897                            PredictEditsRequestTrigger::Other,
1898                            cx,
1899                        )
1900                    })
1901                    .log_err()
1902                else {
1903                    return Task::ready(anyhow::Ok(None));
1904                };
1905
1906                cx.spawn(async move |_cx| {
1907                    request_task.await.map(|prediction_result| {
1908                        prediction_result.map(|prediction_result| {
1909                            (
1910                                prediction_result,
1911                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1912                            )
1913                        })
1914                    })
1915                })
1916            },
1917        )
1918    }
1919
    /// Queues a diagnostics-triggered prediction request for `project`,
    /// throttled per project.
    ///
    /// Nothing is queued in three cases:
    /// - the configured provider is not served by this store;
    /// - the project is unknown to the store;
    /// - the project already has a current prediction, since buffer-originated
    ///   predictions are preferred.
    ///
    /// When the refresh runs, it takes the project's active buffer and cursor. It
    /// bails out if predictions are disabled there. Otherwise it looks for the next
    /// diagnostic location within `scope` and requests a prediction at that
    /// location.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Capture the active buffer, its snapshot and the cursor point. A
                // missing cursor position falls back to `Point::default()`.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    // Local: rows within DIAGNOSTIC_LINES_RANGE of the cursor.
                    // NOTE(review): Global uses `Default::default()` (an empty range);
                    // presumably `next_diagnostic_location` treats that as "no local
                    // restriction" — confirm.
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // If a current prediction appeared while this request was in
                    // flight, keep it. This result is turned into an
                    // `Err(CurrentPreferred)` with the same id, so it is reported
                    // as rejected instead of replacing the current prediction.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
2031
2032    fn predictions_enabled_at(
2033        snapshot: &BufferSnapshot,
2034        position: Option<language::Anchor>,
2035        cx: &App,
2036    ) -> bool {
2037        let file = snapshot.file();
2038        let all_settings = all_language_settings(file, cx);
2039        if !all_settings.show_edit_predictions(snapshot.language(), cx)
2040            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
2041        {
2042            return false;
2043        }
2044
2045        if let Some(last_position) = position {
2046            let settings = snapshot.settings_at(last_position, cx);
2047
2048            if !settings.edit_predictions_disabled_in.is_empty()
2049                && let Some(scope) = snapshot.language_scope_at(last_position)
2050                && let Some(scope_name) = scope.override_name()
2051                && settings
2052                    .edit_predictions_disabled_in
2053                    .iter()
2054                    .any(|s| s == scope_name)
2055            {
2056                return false;
2057            }
2058        }
2059
2060        true
2061    }
2062
    /// Throttle interval used by `queue_prediction_refresh`. A refresh for the same
    /// throttle entity waits until this much time has passed since the previous
    /// refresh recorded in the same throttle slot.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2064}
2065
2066fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2067    match provider {
2068        EditPredictionProvider::Zed
2069        | EditPredictionProvider::Sweep
2070        | EditPredictionProvider::Mercury
2071        | EditPredictionProvider::Ollama
2072        | EditPredictionProvider::OpenAiCompatibleApi
2073        | EditPredictionProvider::Experimental(_) => true,
2074        EditPredictionProvider::None
2075        | EditPredictionProvider::Copilot
2076        | EditPredictionProvider::Codestral => false,
2077    }
2078}
2079
2080impl EditPredictionStore {
2081    fn queue_prediction_refresh(
2082        &mut self,
2083        project: Entity<Project>,
2084        request_trigger: PredictEditsRequestTrigger,
2085        throttle_entity: EntityId,
2086        cx: &mut Context<Self>,
2087        do_refresh: impl FnOnce(
2088            WeakEntity<Self>,
2089            &mut AsyncApp,
2090        )
2091            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
2092        + 'static,
2093    ) {
2094        fn select_throttle(
2095            project_state: &mut ProjectState,
2096            request_trigger: PredictEditsRequestTrigger,
2097        ) -> &mut Option<(EntityId, Instant)> {
2098            match request_trigger {
2099                PredictEditsRequestTrigger::Diagnostics => {
2100                    &mut project_state.last_jump_prediction_refresh
2101                }
2102                _ => &mut project_state.last_edit_prediction_refresh,
2103            }
2104        }
2105
2106        let (needs_acceptance_tracking, max_pending_predictions) =
2107            match all_language_settings(None, cx).edit_predictions.provider {
2108                EditPredictionProvider::Zed
2109                | EditPredictionProvider::Sweep
2110                | EditPredictionProvider::Mercury
2111                | EditPredictionProvider::Experimental(_) => (true, 2),
2112                EditPredictionProvider::Ollama => (false, 1),
2113                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
2114                EditPredictionProvider::None
2115                | EditPredictionProvider::Copilot
2116                | EditPredictionProvider::Codestral => {
2117                    log::error!("queue_prediction_refresh called with non-store provider");
2118                    return;
2119                }
2120            };
2121
2122        let drop_on_cancel = !needs_acceptance_tracking;
2123        let throttle_timeout = Self::THROTTLE_TIMEOUT;
2124        let project_state = self.get_or_init_project(&project, cx);
2125        let pending_prediction_id = project_state.next_pending_prediction_id;
2126        project_state.next_pending_prediction_id += 1;
2127        let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2128
2129        let task = cx.spawn(async move |this, cx| {
2130            let throttle_wait = this
2131                .update(cx, |this, cx| {
2132                    let project_state = this.get_or_init_project(&project, cx);
2133                    let throttle = *select_throttle(project_state, request_trigger);
2134
2135                    throttle.and_then(|(last_entity, last_timestamp)| {
2136                        if throttle_entity != last_entity {
2137                            return None;
2138                        }
2139                        (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
2140                    })
2141                })
2142                .ok()
2143                .flatten();
2144
2145            if let Some(timeout) = throttle_wait {
2146                cx.background_executor().timer(timeout).await;
2147            }
2148
2149            // If this task was cancelled before the throttle timeout expired,
2150            // do not perform a request. Also skip if another task already
2151            // proceeded since we were enqueued (duplicate).
2152            let mut is_cancelled = true;
2153            this.update(cx, |this, cx| {
2154                let project_state = this.get_or_init_project(&project, cx);
2155                let was_cancelled = project_state
2156                    .cancelled_predictions
2157                    .remove(&pending_prediction_id);
2158                if was_cancelled {
2159                    return;
2160                }
2161
2162                // Another request has been already sent since this was enqueued
2163                if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2164                    return;
2165                }
2166
2167                let new_refresh = (throttle_entity, Instant::now());
2168                *select_throttle(project_state, request_trigger) = Some(new_refresh);
2169                is_cancelled = false;
2170            })
2171            .ok();
2172            if is_cancelled {
2173                return None;
2174            }
2175
2176            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2177            let new_prediction_id = new_prediction_result
2178                .as_ref()
2179                .map(|(prediction, _)| prediction.id.clone());
2180
2181            // When a prediction completes, remove it from the pending list, and cancel
2182            // any pending predictions that were enqueued before it.
2183            this.update(cx, |this, cx| {
2184                let project_state = this.get_or_init_project(&project, cx);
2185
2186                let is_cancelled = project_state
2187                    .cancelled_predictions
2188                    .remove(&pending_prediction_id);
2189
2190                let new_current_prediction = if !is_cancelled
2191                    && let Some((prediction_result, requested_by)) = new_prediction_result
2192                {
2193                    match prediction_result.prediction {
2194                        Ok(prediction) => {
2195                            let new_prediction = CurrentEditPrediction {
2196                                requested_by,
2197                                prediction,
2198                                was_shown: false,
2199                                shown_with: None,
2200                            };
2201
2202                            if let Some(current_prediction) =
2203                                project_state.current_prediction.as_ref()
2204                            {
2205                                if new_prediction.should_replace_prediction(&current_prediction, cx)
2206                                {
2207                                    this.reject_current_prediction(
2208                                        EditPredictionRejectReason::Replaced,
2209                                        &project,
2210                                        cx,
2211                                    );
2212
2213                                    Some(new_prediction)
2214                                } else {
2215                                    this.reject_prediction(
2216                                        new_prediction.prediction.id,
2217                                        EditPredictionRejectReason::CurrentPreferred,
2218                                        false,
2219                                        new_prediction.prediction.model_version,
2220                                        cx,
2221                                    );
2222                                    None
2223                                }
2224                            } else {
2225                                Some(new_prediction)
2226                            }
2227                        }
2228                        Err(reject_reason) => {
2229                            this.reject_prediction(
2230                                prediction_result.id,
2231                                reject_reason,
2232                                false,
2233                                None,
2234                                cx,
2235                            );
2236                            None
2237                        }
2238                    }
2239                } else {
2240                    None
2241                };
2242
2243                let project_state = this.get_or_init_project(&project, cx);
2244
2245                if let Some(new_prediction) = new_current_prediction {
2246                    project_state.current_prediction = Some(new_prediction);
2247                }
2248
2249                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2250                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2251                    if pending_prediction.id == pending_prediction_id {
2252                        pending_predictions.remove(ix);
2253                        for pending_prediction in pending_predictions.drain(0..ix) {
2254                            project_state.cancel_pending_prediction(pending_prediction, cx)
2255                        }
2256                        break;
2257                    }
2258                }
2259                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2260                cx.notify();
2261            })
2262            .ok();
2263
2264            new_prediction_id
2265        });
2266
2267        if project_state.pending_predictions.len() < max_pending_predictions {
2268            project_state.pending_predictions.push(PendingPrediction {
2269                id: pending_prediction_id,
2270                task,
2271                drop_on_cancel,
2272            });
2273        } else {
2274            let pending_prediction = project_state.pending_predictions.pop().unwrap();
2275            project_state.pending_predictions.push(PendingPrediction {
2276                id: pending_prediction_id,
2277                task,
2278                drop_on_cancel,
2279            });
2280            project_state.cancel_pending_prediction(pending_prediction, cx);
2281        }
2282    }
2283
2284    pub fn request_prediction(
2285        &mut self,
2286        project: &Entity<Project>,
2287        active_buffer: &Entity<Buffer>,
2288        position: language::Anchor,
2289        trigger: PredictEditsRequestTrigger,
2290        cx: &mut Context<Self>,
2291    ) -> Task<Result<Option<EditPredictionResult>>> {
2292        self.request_prediction_internal(
2293            project.clone(),
2294            active_buffer.clone(),
2295            position,
2296            trigger,
2297            cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2298            cx,
2299        )
2300    }
2301
2302    fn request_prediction_internal(
2303        &mut self,
2304        project: Entity<Project>,
2305        active_buffer: Entity<Buffer>,
2306        position: language::Anchor,
2307        trigger: PredictEditsRequestTrigger,
2308        allow_jump: bool,
2309        cx: &mut Context<Self>,
2310    ) -> Task<Result<Option<EditPredictionResult>>> {
2311        self.get_or_init_project(&project, cx);
2312        let project_state = self.projects.get(&project.entity_id()).unwrap();
2313        let stored_events = project_state.events(cx);
2314        let has_events = !stored_events.is_empty();
2315        let events: Vec<Arc<zeta_prompt::Event>> =
2316            stored_events.iter().map(|e| e.event.clone()).collect();
2317        let debug_tx = project_state.debug_tx.clone();
2318
2319        let snapshot = active_buffer.read(cx).snapshot();
2320        let cursor_point = position.to_point(&snapshot);
2321        let current_offset = position.to_offset(&snapshot);
2322
2323        let mut user_actions: Vec<UserActionRecord> =
2324            project_state.user_actions.iter().cloned().collect();
2325
2326        if let Some(last_action) = user_actions.last() {
2327            if last_action.buffer_id == active_buffer.entity_id()
2328                && current_offset != last_action.offset
2329            {
2330                let timestamp_epoch_ms = SystemTime::now()
2331                    .duration_since(UNIX_EPOCH)
2332                    .map(|d| d.as_millis() as u64)
2333                    .unwrap_or(0);
2334                user_actions.push(UserActionRecord {
2335                    action_type: UserActionType::CursorMovement,
2336                    buffer_id: active_buffer.entity_id(),
2337                    line_number: cursor_point.row,
2338                    offset: current_offset,
2339                    timestamp_epoch_ms,
2340                });
2341            }
2342        }
2343        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2344        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2345        let diagnostic_search_range =
2346            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2347
2348        let related_files = self.context_for_project(&project, cx);
2349
2350        let is_open_source = snapshot
2351            .file()
2352            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2353            && events.iter().all(|event| event.in_open_source_repo())
2354            && related_files.iter().all(|file| file.in_open_source_repo);
2355
2356        let can_collect_data = !cfg!(test)
2357            && is_open_source
2358            && self.is_data_collection_enabled(cx)
2359            && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2360
2361        let recent_paths = project_state.recent_paths.clone();
2362
2363        let inputs = EditPredictionModelInput {
2364            project: project.clone(),
2365            buffer: active_buffer,
2366            snapshot,
2367            position,
2368            events,
2369            related_files,
2370            recent_paths,
2371            trigger,
2372            diagnostic_search_range: diagnostic_search_range,
2373            debug_tx,
2374            user_actions,
2375            can_collect_data,
2376            is_open_source,
2377        };
2378
2379        let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2380
2381        let task = match self.edit_prediction_model {
2382            EditPredictionModel::Zeta => {
2383                zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2384            }
2385            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2386            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2387            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2388        };
2389
2390        cx.spawn(async move |this, cx| {
2391            let prediction = task.await?;
2392
2393            // Only fall back to diagnostics-based prediction if we got a
2394            // the model had nothing to suggest for the buffer
2395            if prediction.is_none()
2396                && allow_jump
2397                && has_events
2398                && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2399            {
2400                this.update(cx, |this, cx| {
2401                    this.refresh_prediction_from_diagnostics(
2402                        project,
2403                        DiagnosticSearchScope::Local,
2404                        cx,
2405                    );
2406                })?;
2407                return anyhow::Ok(None);
2408            }
2409
2410            Ok(prediction)
2411        })
2412    }
2413
2414    pub(crate) async fn next_diagnostic_location(
2415        active_buffer: Entity<Buffer>,
2416        active_buffer_snapshot: &BufferSnapshot,
2417        active_buffer_diagnostic_search_range: Range<Point>,
2418        active_buffer_cursor_point: Point,
2419        project: &Entity<Project>,
2420        cx: &mut AsyncApp,
2421    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2422        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2423            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2424            .flat_map(|(_, _, _, selections)| {
2425                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2426            })
2427            .collect();
2428
2429        let mut jump_location = active_buffer_snapshot
2430            .diagnostic_groups(None)
2431            .into_iter()
2432            .filter_map(|(_, group)| {
2433                let range = &group.entries[group.primary_ix]
2434                    .range
2435                    .to_point(&active_buffer_snapshot);
2436                if range.overlaps(&active_buffer_diagnostic_search_range) {
2437                    return None;
2438                }
2439                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2440                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2441                });
2442                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2443                    <= DIAGNOSTIC_LINES_RANGE;
2444                if near_collaborator && !near_local {
2445                    return None;
2446                }
2447                Some(range.start)
2448            })
2449            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2450            .map(|position| {
2451                (
2452                    active_buffer.clone(),
2453                    active_buffer_snapshot.anchor_before(position),
2454                )
2455            });
2456
2457        if jump_location.is_none() {
2458            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2459                let file = buffer.file()?;
2460
2461                Some(ProjectPath {
2462                    worktree_id: file.worktree_id(cx),
2463                    path: file.path().clone(),
2464                })
2465            });
2466
2467            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2468                project
2469                    .diagnostic_summaries(false, cx)
2470                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2471                    .map(|(path, _, _)| {
2472                        let shared_prefix = path
2473                            .path
2474                            .components()
2475                            .zip(
2476                                active_buffer_path
2477                                    .as_ref()
2478                                    .map(|p| p.path.components())
2479                                    .unwrap_or_default(),
2480                            )
2481                            .take_while(|(a, b)| a == b)
2482                            .count();
2483                        (path, shared_prefix)
2484                    })
2485                    .collect()
2486            });
2487
2488            candidates.sort_by(|a, b| b.1.cmp(&a.1));
2489
2490            for (path, _) in candidates {
2491                let candidate_buffer = project
2492                    .update(cx, |project, cx| project.open_buffer(path, cx))
2493                    .await?;
2494
2495                let (has_collaborators, diagnostic_position) =
2496                    candidate_buffer.read_with(cx, |buffer, _cx| {
2497                        let snapshot = buffer.snapshot();
2498                        let has_collaborators = snapshot
2499                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2500                            .next()
2501                            .is_some();
2502                        let position = buffer
2503                            .buffer_diagnostics(None)
2504                            .into_iter()
2505                            .min_by_key(|entry| entry.diagnostic.severity)
2506                            .map(|entry| entry.range.start);
2507                        (has_collaborators, position)
2508                    });
2509
2510                if has_collaborators {
2511                    continue;
2512                }
2513
2514                if let Some(position) = diagnostic_position {
2515                    jump_location = Some((candidate_buffer, position));
2516                    break;
2517                }
2518            }
2519        }
2520
2521        anyhow::Ok(jump_location)
2522    }
2523
2524    async fn send_raw_llm_request(
2525        request: RawCompletionRequest,
2526        client: Arc<Client>,
2527        custom_url: Option<Arc<Url>>,
2528        llm_token: LlmApiToken,
2529        organization_id: Option<OrganizationId>,
2530        app_version: Version,
2531    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2532        let url = if let Some(custom_url) = custom_url {
2533            custom_url.as_ref().clone()
2534        } else {
2535            client
2536                .http_client()
2537                .build_zed_llm_url("/predict_edits/raw", &[])?
2538        };
2539
2540        Self::send_api_request(
2541            |builder| {
2542                let req = builder
2543                    .uri(url.as_ref())
2544                    .body(serde_json::to_string(&request)?.into());
2545                Ok(req?)
2546            },
2547            client,
2548            llm_token,
2549            organization_id,
2550            app_version,
2551            true,
2552        )
2553        .await
2554    }
2555
2556    pub(crate) async fn send_v3_request(
2557        input: ZetaPromptInput,
2558        client: Arc<Client>,
2559        llm_token: LlmApiToken,
2560        organization_id: Option<OrganizationId>,
2561        app_version: Version,
2562        trigger: PredictEditsRequestTrigger,
2563    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2564        let url = client
2565            .http_client()
2566            .build_zed_llm_url("/predict_edits/v3", &[])?;
2567
2568        let request = PredictEditsV3Request { input, trigger };
2569
2570        let json_bytes = serde_json::to_vec(&request)?;
2571        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2572
2573        Self::send_api_request(
2574            |builder| {
2575                let req = builder
2576                    .uri(url.as_ref())
2577                    .header("Content-Encoding", "zstd")
2578                    .body(compressed.clone().into());
2579                Ok(req?)
2580            },
2581            client,
2582            llm_token,
2583            organization_id,
2584            app_version,
2585            true,
2586        )
2587        .await
2588    }
2589
2590    async fn send_api_request<Res>(
2591        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2592        client: Arc<Client>,
2593        llm_token: LlmApiToken,
2594        organization_id: Option<OrganizationId>,
2595        app_version: Version,
2596        require_auth: bool,
2597    ) -> Result<(Res, Option<EditPredictionUsage>)>
2598    where
2599        Res: DeserializeOwned,
2600    {
2601        let http_client = client.http_client();
2602
2603        let mut token = if require_auth {
2604            Some(llm_token.acquire(&client, organization_id.clone()).await?)
2605        } else {
2606            llm_token
2607                .acquire(&client, organization_id.clone())
2608                .await
2609                .ok()
2610        };
2611        let mut did_retry = false;
2612
2613        loop {
2614            let request_builder = http_client::Request::builder().method(Method::POST);
2615
2616            let mut request_builder = request_builder
2617                .header("Content-Type", "application/json")
2618                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2619
2620            // Only add Authorization header if we have a token
2621            if let Some(ref token_value) = token {
2622                request_builder =
2623                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2624            }
2625
2626            let request = build(request_builder)?;
2627
2628            let mut response = http_client.send(request).await?;
2629
2630            if let Some(minimum_required_version) = response
2631                .headers()
2632                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2633                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2634            {
2635                anyhow::ensure!(
2636                    app_version >= minimum_required_version,
2637                    ZedUpdateRequiredError {
2638                        minimum_version: minimum_required_version
2639                    }
2640                );
2641            }
2642
2643            if response.status().is_success() {
2644                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2645
2646                let mut body = Vec::new();
2647                response.body_mut().read_to_end(&mut body).await?;
2648                return Ok((serde_json::from_slice(&body)?, usage));
2649            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2650                did_retry = true;
2651                token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
2652            } else {
2653                let mut body = String::new();
2654                response.body_mut().read_to_string(&mut body).await?;
2655                anyhow::bail!(
2656                    "Request failed with status: {:?}\nBody: {}",
2657                    response.status(),
2658                    body
2659                );
2660            }
2661        }
2662    }
2663
2664    pub fn refresh_context(
2665        &mut self,
2666        project: &Entity<Project>,
2667        buffer: &Entity<language::Buffer>,
2668        cursor_position: language::Anchor,
2669        cx: &mut Context<Self>,
2670    ) {
2671        self.get_or_init_project(project, cx)
2672            .context
2673            .update(cx, |store, cx| {
2674                store.refresh(buffer.clone(), cursor_position, cx);
2675            });
2676    }
2677
2678    #[cfg(feature = "cli-support")]
2679    pub fn set_context_for_buffer(
2680        &mut self,
2681        project: &Entity<Project>,
2682        related_files: Vec<RelatedFile>,
2683        cx: &mut Context<Self>,
2684    ) {
2685        self.get_or_init_project(project, cx)
2686            .context
2687            .update(cx, |store, cx| {
2688                store.set_related_files(related_files, cx);
2689            });
2690    }
2691
2692    #[cfg(feature = "cli-support")]
2693    pub fn set_recent_paths_for_project(
2694        &mut self,
2695        project: &Entity<Project>,
2696        paths: impl IntoIterator<Item = project::ProjectPath>,
2697        cx: &mut Context<Self>,
2698    ) {
2699        let project_state = self.get_or_init_project(project, cx);
2700        project_state.recent_paths = paths.into_iter().collect();
2701    }
2702
2703    fn is_file_open_source(
2704        &self,
2705        project: &Entity<Project>,
2706        file: &Arc<dyn File>,
2707        cx: &App,
2708    ) -> bool {
2709        if !file.is_local() || file.is_private() {
2710            return false;
2711        }
2712        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2713            return false;
2714        };
2715        project_state
2716            .license_detection_watchers
2717            .get(&file.worktree_id(cx))
2718            .as_ref()
2719            .is_some_and(|watcher| watcher.is_project_open_source())
2720    }
2721
2722    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2723        self.data_collection_choice.is_enabled(cx)
2724    }
2725
2726    fn load_data_collection_choice() -> DataCollectionChoice {
2727        let choice = KEY_VALUE_STORE
2728            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2729            .log_err()
2730            .flatten();
2731
2732        match choice.as_deref() {
2733            Some("true") => DataCollectionChoice::Enabled,
2734            Some("false") => DataCollectionChoice::Disabled,
2735            Some(_) => {
2736                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2737                DataCollectionChoice::NotAnswered
2738            }
2739            None => DataCollectionChoice::NotAnswered,
2740        }
2741    }
2742
2743    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2744        self.data_collection_choice = self.data_collection_choice.toggle();
2745        let new_choice = self.data_collection_choice;
2746        let is_enabled = new_choice.is_enabled(cx);
2747        db::write_and_log(cx, move || {
2748            KEY_VALUE_STORE.write_kvp(
2749                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2750                is_enabled.to_string(),
2751            )
2752        });
2753    }
2754
2755    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2756        self.shown_predictions.iter()
2757    }
2758
2759    pub fn shown_completions_len(&self) -> usize {
2760        self.shown_predictions.len()
2761    }
2762
2763    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2764        self.rated_predictions.contains(id)
2765    }
2766
2767    pub fn rate_prediction(
2768        &mut self,
2769        prediction: &EditPrediction,
2770        rating: EditPredictionRating,
2771        feedback: String,
2772        cx: &mut Context<Self>,
2773    ) {
2774        let organization = self.user_store.read(cx).current_organization();
2775
2776        self.rated_predictions.insert(prediction.id.clone());
2777
2778        cx.background_spawn({
2779            let client = self.client.clone();
2780            let prediction_id = prediction.id.to_string();
2781            let inputs = serde_json::to_value(&prediction.inputs);
2782            let output = prediction
2783                .edit_preview
2784                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2785            async move {
2786                client
2787                    .cloud_client()
2788                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2789                        organization_id: organization.map(|organization| organization.id.clone()),
2790                        request_id: prediction_id,
2791                        rating: match rating {
2792                            EditPredictionRating::Positive => "positive".to_string(),
2793                            EditPredictionRating::Negative => "negative".to_string(),
2794                        },
2795                        inputs: inputs?,
2796                        output,
2797                        feedback,
2798                    })
2799                    .await?;
2800
2801                anyhow::Ok(())
2802            }
2803        })
2804        .detach_and_log_err(cx);
2805
2806        cx.notify();
2807    }
2808}
2809
2810fn collaborator_edit_overlaps_locality_region(
2811    project_state: &ProjectState,
2812    project: &Entity<Project>,
2813    buffer: &Entity<Buffer>,
2814    snapshot: &BufferSnapshot,
2815    edit_range: &Range<Anchor>,
2816    cx: &App,
2817) -> bool {
2818    let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2819        return false;
2820    };
2821
2822    if active_buffer.entity_id() != buffer.entity_id() {
2823        return false;
2824    }
2825
2826    let locality_point_range = expand_context_syntactically_then_linewise(
2827        snapshot,
2828        (position..position).to_point(snapshot),
2829        COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2830    );
2831    let locality_anchor_range = snapshot.anchor_range_around(locality_point_range);
2832
2833    edit_range.overlaps(&locality_anchor_range, snapshot)
2834}
2835
/// Collapses the trailing run of mergeable events at the back of `events` into one `StoredEvent`.
///
/// The merged event's diff runs from the oldest merged event's `old_snapshot` to `end_snapshot`.
/// It is restricted to the union of the merged events' `total_edit_range`s, with anchors compared
/// in `latest_snapshot`. Whether two neighbouring events can merge is decided by
/// `StoredEvent::can_merge`, given `latest_snapshot` and `latest_edit_range`.
///
/// `events` is left unchanged when any of these hold:
/// - the newest event's `old_snapshot` has a different `remote_id` than `latest_snapshot`;
/// - `latest_snapshot` has not observed the newest event's `new_snapshot_version`;
/// - fewer than two trailing events are mergeable;
/// - `compute_diff_between_snapshots_in_range` returns `None` for the merged range.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Walk newest -> oldest and count how many consecutive events can merge with their newer
    // neighbour. The newest event is always counted (it has no newer neighbour to check).
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A single event (or none) has nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    // Oldest-first view of the run. `peek` exposes the oldest event without consuming it, so
    // the `all` call below still visits it.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    // Widen the range to cover every newer event in the run.
    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        // The match is exhaustive with a single arm; path metadata comes from the oldest event.
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged event counts as predicted only if every merged event was.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Replace the whole merged run with the single merged event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2918
2919fn merge_anchor_ranges(
2920    left: &Range<Anchor>,
2921    right: &Range<Anchor>,
2922    snapshot: &TextBufferSnapshot,
2923) -> Range<Anchor> {
2924    let start = if left.start.cmp(&right.start, snapshot).is_le() {
2925        left.start
2926    } else {
2927        right.start
2928    };
2929    let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2930        left.end
2931    } else {
2932        right.end
2933    };
2934    start..end
2935}
2936
2937#[derive(Error, Debug)]
2938#[error(
2939    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2940)]
2941pub struct ZedUpdateRequiredError {
2942    minimum_version: Version,
2943}
2944
/// The user's recorded choice about edit-prediction data collection.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been recorded yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2951
2952impl DataCollectionChoice {
2953    pub fn is_enabled(self, cx: &App) -> bool {
2954        if cx.is_staff() {
2955            return true;
2956        }
2957        match self {
2958            Self::Enabled => true,
2959            Self::NotAnswered | Self::Disabled => false,
2960        }
2961    }
2962
2963    #[must_use]
2964    pub fn toggle(&self) -> DataCollectionChoice {
2965        match self {
2966            Self::Enabled => Self::Disabled,
2967            Self::Disabled => Self::Enabled,
2968            Self::NotAnswered => Self::Enabled,
2969        }
2970    }
2971}
2972
2973impl From<bool> for DataCollectionChoice {
2974    fn from(value: bool) -> Self {
2975        match value {
2976            true => DataCollectionChoice::Enabled,
2977            false => DataCollectionChoice::Disabled,
2978        }
2979    }
2980}
2981
2982struct ZedPredictUpsell;
2983
2984impl Dismissable for ZedPredictUpsell {
2985    const KEY: &'static str = "dismissed-edit-predict-upsell";
2986
2987    fn dismissed() -> bool {
2988        // To make this backwards compatible with older versions of Zed, we
2989        // check if the user has seen the previous Edit Prediction Onboarding
2990        // before, by checking the data collection choice which was written to
2991        // the database once the user clicked on "Accept and Enable"
2992        if KEY_VALUE_STORE
2993            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2994            .log_err()
2995            .is_some_and(|s| s.is_some())
2996        {
2997            return true;
2998        }
2999
3000        KEY_VALUE_STORE
3001            .read_kvp(Self::KEY)
3002            .log_err()
3003            .is_some_and(|s| s.is_some())
3004    }
3005}
3006
3007pub fn should_show_upsell_modal() -> bool {
3008    !ZedPredictUpsell::dismissed()
3009}
3010
3011pub fn init(cx: &mut App) {
3012    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
3013        workspace.register_action(
3014            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
3015                ZedPredictModal::toggle(
3016                    workspace,
3017                    workspace.user_store().clone(),
3018                    workspace.client().clone(),
3019                    window,
3020                    cx,
3021                )
3022            },
3023        );
3024
3025        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
3026            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
3027                settings
3028                    .project
3029                    .all_languages
3030                    .edit_predictions
3031                    .get_or_insert_default()
3032                    .provider = Some(EditPredictionProvider::None)
3033            });
3034        });
3035        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
3036            EditPredictionStore::try_global(cx).and_then(|store| {
3037                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
3038            })
3039        }
3040
3041        workspace.register_action(|workspace, _: &SignIn, window, cx| {
3042            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3043                copilot_ui::initiate_sign_in(copilot, window, cx);
3044            }
3045        });
3046        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3047            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3048                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3049            }
3050        });
3051        workspace.register_action(|workspace, _: &SignOut, window, cx| {
3052            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3053                copilot_ui::initiate_sign_out(copilot, window, cx);
3054            }
3055        });
3056    })
3057    .detach();
3058}