edit_prediction.rs

   1use anyhow::Result;
   2use client::{Client, EditPredictionUsage, UserStore};
   3use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
   4use cloud_llm_client::predict_edits_v3::{
   5    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   6};
   7use cloud_llm_client::{
   8    EditPredictionRejectReason, EditPredictionRejection,
   9    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  10    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  11};
  12use collections::{HashMap, HashSet};
  13use copilot::{Copilot, Reinstall, SignIn, SignOut};
  14use db::kvp::{Dismissable, KeyValueStore};
  15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::{
  18    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  19    channel::mpsc::{self, UnboundedReceiver},
  20    select_biased,
  21};
  22use gpui::BackgroundExecutor;
  23use gpui::http_client::Url;
  24use gpui::{
  25    App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
  26    http_client::{self, AsyncBody, Method},
  27    prelude::*,
  28};
  29use heapless::Vec as ArrayVec;
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{
  39    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  40};
  41use std::collections::{VecDeque, hash_map};
  42use std::env;
  43use text::{AnchorRangeExt, Edit};
  44use workspace::{AppState, Workspace};
  45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  46
  47use std::mem;
  48use std::ops::Range;
  49use std::path::Path;
  50use std::rc::Rc;
  51use std::str::FromStr as _;
  52use std::sync::Arc;
  53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  54use thiserror::Error;
  55use util::{RangeExt as _, ResultExt as _};
  56
  57pub mod cursor_excerpt;
  58pub mod example_spec;
  59pub mod fim;
  60mod license_detection;
  61pub mod mercury;
  62pub mod ollama;
  63mod onboarding_modal;
  64pub mod open_ai_response;
  65mod prediction;
  66pub mod sweep_ai;
  67
  68pub mod udiff;
  69
  70mod capture_example;
  71pub mod open_ai_compatible;
  72mod zed_edit_prediction_delegate;
  73pub mod zeta;
  74
  75#[cfg(test)]
  76mod edit_prediction_tests;
  77
  78use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
  79use crate::example_spec::ExampleSpec;
  80use crate::license_detection::LicenseDetectionWatcher;
  81use crate::mercury::Mercury;
  82use crate::onboarding_modal::ZedPredictModal;
  83pub use crate::prediction::EditPrediction;
  84pub use crate::prediction::EditPredictionId;
  85use crate::prediction::EditPredictionResult;
  86pub use crate::sweep_ai::SweepAi;
  87pub use capture_example::capture_example;
  88pub use language_model::ApiKeyState;
  89pub use telemetry_events::EditPredictionRating;
  90pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  91
// Declares the gpui actions in the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
 101
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Row-distance threshold used by `StoredEvent::can_merge`: events within this many rows
/// of the latest edit are kept separate, and events farther apart than this are not merged.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Upper bound (in buffer offset units) on the old and new text regions diffed for one
/// history event; if either region is larger, no diff is produced for that event.
const EDIT_HISTORY_DIFF_SIZE_LIMIT: usize = 2048 * 3; // ~2048 tokens or ~50% of typical prompt budget
// NOTE(review): not referenced in this part of the file; presumably a token budget for
// context around collaborator edits — confirm at the use site.
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
// NOTE(review): presumably the window within which consecutive edits are grouped into a
// single event — confirm at the use site.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// Key presumably used to persist the user's data-collection choice — confirm.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// Presumably how long rejections are batched before being sent — confirm at the use site.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
// Name of the "settled" event (presumably a telemetry event name — confirm).
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
// Five minutes; presumably how long a shown prediction is tracked before being dropped.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
// Presumably how long a region must go unedited before its prediction counts as settled.
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
 113
/// Feature flag named `edit_prediction_jumps`.
// NOTE(review): presumably gates jump-style predictions (`BufferEditPrediction::Jump`) —
// confirm at the use site.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
 119
/// gpui global holding the app-wide [`EditPredictionStore`] entity. It is installed by
/// `EditPredictionStore::global` and read by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 124
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// `EditPredictionStore::zeta2_raw_config_from_env` builds one from `ZED_ZETA_FORMAT`
/// (required), `ZED_ZETA_MODEL`, and `ZED_ZETA_ENVIRONMENT`. One can also be installed
/// directly with `EditPredictionStore::set_zeta2_raw_config`.
// NOTE(review): an older comment said "the version is also used as the Baseten environment
// name (lowercased)". There is no version field any more, and the environment is now a
// separate `environment` field. Confirm how `environment` is consumed and whether it is
// lowercased.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Model identifier (from `ZED_ZETA_MODEL` when read from the environment).
    pub model_id: Option<String>,
    /// Environment name (from `ZED_ZETA_ENVIRONMENT` when read from the environment).
    pub environment: Option<String>,
    /// Prompt format (parsed from `ZED_ZETA_FORMAT` when read from the environment).
    pub format: ZetaFormat,
}
 134
/// App-wide state for edit predictions. It holds:
/// - per-project history and pending requests;
/// - the selected prediction backend;
/// - experiment selection;
/// - the channels that feed the rejection and "settled" background workers.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    /// Task that waits for a signed-in user and then calls `refresh_available_experiments`.
    _fetch_experiments_task: Task<()>,
    /// Per-project state, keyed by `EntityId` (presumably the project entity's id — confirm).
    projects: HashMap<EntityId, ProjectState>,
    // NOTE(review): starts `false`; presumably set when the server demands a newer client.
    update_required: bool,
    /// Backend used to produce predictions; starts as `EditPredictionModel::Zeta`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration, read from the environment at construction.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    /// Explicitly chosen experiment; `active_experiment` prefers it over anything else.
    preferred_experiment: Option<String>,
    /// Experiments available to the user (presumably filled by `refresh_available_experiments`).
    available_experiments: Vec<String>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Sender for the background task that reports rejected predictions.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Sender for the task running `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    /// Predictions that were shown; `active_experiment` reads model versions from here.
    shown_predictions: VecDeque<EditPrediction>,
    // Presumably ids of predictions the user has already rated — confirm.
    rated_predictions: HashSet<EditPredictionId>,
    #[cfg(test)]
    // Test-only hook; presumably invoked instead of emitting the "settled" event.
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
 156
/// Message sent over `reject_predictions_tx` to the background rejection reporter.
pub(crate) struct EditPredictionRejectionPayload {
    /// The rejection to report.
    rejection: EditPredictionRejection,
    /// Organization to attribute the rejection to, if any.
    organization_id: Option<OrganizationId>,
}
 161
/// Backend that produces edit predictions. `EditPredictionStore::new` starts with `Zeta`.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    /// The Zeta model.
    Zeta,
    /// Presumably fill-in-the-middle prompting (see the `fim` module) using `format`.
    Fim { format: EditPredictionPromptFormat },
    /// Sweep AI backend (see `EditPredictionStore::sweep_ai`).
    Sweep,
    /// Mercury backend (see `EditPredictionStore::mercury`).
    Mercury,
}
 169
/// Inputs gathered for a single prediction request.
// NOTE(review): per-field meanings below are inferred from names and types; confirm against
// the backends that consume this struct.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    /// Buffer the prediction is requested for.
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` at request time.
    snapshot: BufferSnapshot,
    /// Position in `buffer` the request is anchored at (presumably the cursor).
    position: Anchor,
    /// Recent edit-history events.
    events: Vec<Arc<zeta_prompt::Event>>,
    /// Related excerpts from other files.
    related_files: Vec<RelatedFile>,
    /// Recently visited project paths.
    recent_paths: VecDeque<ProjectPath>,
    /// What triggered the request.
    trigger: PredictEditsRequestTrigger,
    /// Row range searched for diagnostics.
    diagnostic_search_range: Range<Point>,
    /// Debug event sink, if a debug view is attached.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    /// Recent user actions (bounded by `USER_ACTION_HISTORY_SIZE` in `ProjectState`).
    pub user_actions: Vec<UserActionRecord>,
}
 186
/// Events emitted on a project's debug channel to trace context retrieval and predictions.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
 194
/// Payload of [`DebugEvent::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}
 201
/// Payload of [`DebugEvent::ContextRetrievalFinished`].
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Key/value pairs describing the retrieval result.
    pub metadata: Vec<(&'static str, SharedString)>,
}
 208
/// Payload of [`DebugEvent::EditPredictionStarted`].
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Prompt sent to the model, when one is available.
    pub prompt: Option<String>,
}
 215
/// Payload of [`DebugEvent::EditPredictionFinished`].
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when one is available.
    pub model_output: Option<String>,
}
 222
/// Capacity of `ProjectState::user_actions`; the oldest action is evicted beyond this.
const USER_ACTION_HISTORY_SIZE: usize = 16;
 224
/// One user action in a buffer, kept in a bounded per-project history.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    /// Time of the action; per the name, milliseconds since the Unix epoch.
    pub timestamp_epoch_ms: u64,
}
 233
/// Kind of a [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 242
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-level event (for buffer changes, a unified diff plus paths and flags).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Buffer version after the change.
    pub new_snapshot_version: clock::Global,
    /// Edited region, as anchors created in the post-change snapshot.
    pub total_edit_range: Range<Anchor>,
}
 251
 252impl StoredEvent {
 253    fn can_merge(
 254        &self,
 255        next_old_event: &StoredEvent,
 256        latest_snapshot: &TextBufferSnapshot,
 257        latest_edit_range: &Range<Anchor>,
 258    ) -> bool {
 259        // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
 260        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 261            return false;
 262        }
 263        if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
 264            return false;
 265        }
 266        if self.new_snapshot_version != next_old_event.old_snapshot.version {
 267            return false;
 268        }
 269        if !latest_snapshot
 270            .version
 271            .observed_all(&next_old_event.new_snapshot_version)
 272        {
 273            return false;
 274        }
 275
 276        let a_is_predicted = matches!(
 277            self.event.as_ref(),
 278            zeta_prompt::Event::BufferChange {
 279                predicted: true,
 280                ..
 281            }
 282        );
 283        let b_is_predicted = matches!(
 284            next_old_event.event.as_ref(),
 285            zeta_prompt::Event::BufferChange {
 286                predicted: true,
 287                ..
 288            }
 289        );
 290
 291        // If events come from the same source (both predicted or both manual) then
 292        // we would have coalesced them already.
 293        if a_is_predicted == b_is_predicted {
 294            return false;
 295        }
 296
 297        let left_range = self.total_edit_range.to_point(latest_snapshot);
 298        let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
 299        let latest_range = latest_edit_range.to_point(latest_snapshot);
 300
 301        // Events near to the latest edit are not merged if their sources differ.
 302        if lines_between_ranges(&left_range, &latest_range)
 303            .min(lines_between_ranges(&right_range, &latest_range))
 304            <= CHANGE_GROUPING_LINE_SPAN
 305        {
 306            return false;
 307        }
 308
 309        // Events that are distant from each other are not merged.
 310        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 311            return false;
 312        }
 313
 314        true
 315    }
 316}
 317
 318fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 319    if left.start > right.end {
 320        return left.start.row - right.end.row;
 321    }
 322    if right.start > left.end {
 323        return right.start.row - left.end.row;
 324    }
 325    0
 326}
 327
/// Edit-prediction state tracked for one project.
struct ProjectState {
    /// Finalized edit-history events.
    events: VecDeque<StoredEvent>,
    /// In-progress event; `events()` splits it at the last editing pause and finalizes it.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers this project tracks, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two).
    pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// Per-worktree license watchers used to decide whether an event is open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history (see `USER_ACTION_HISTORY_SIZE`).
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
 346
 347impl ProjectState {
 348    fn record_user_action(&mut self, action: UserActionRecord) {
 349        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 350            self.user_actions.pop_front();
 351        }
 352        self.user_actions.push_back(action);
 353    }
 354
 355    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 356        self.events
 357            .iter()
 358            .cloned()
 359            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 360                let (one, two) = event.split_by_pause();
 361                let one = one.finalize(&self.license_detection_watchers, cx);
 362                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 363                one.into_iter().chain(two)
 364            }))
 365            .collect()
 366    }
 367
 368    fn cancel_pending_prediction(
 369        &mut self,
 370        pending_prediction: PendingPrediction,
 371        cx: &mut Context<EditPredictionStore>,
 372    ) {
 373        self.cancelled_predictions.insert(pending_prediction.id);
 374
 375        if pending_prediction.drop_on_cancel {
 376            drop(pending_prediction.task);
 377        } else {
 378            cx.spawn(async move |this, cx| {
 379                let Some(prediction_id) = pending_prediction.task.await else {
 380                    return;
 381                };
 382
 383                this.update(cx, |this, cx| {
 384                    this.reject_prediction(
 385                        prediction_id,
 386                        EditPredictionRejectReason::Canceled,
 387                        false,
 388                        None,
 389                        None,
 390                        cx,
 391                    );
 392                })
 393                .ok();
 394            })
 395            .detach()
 396        }
 397    }
 398
 399    fn active_buffer(
 400        &self,
 401        project: &Entity<Project>,
 402        cx: &App,
 403    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 404        let project = project.read(cx);
 405        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 406        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 407        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 408        Some((active_buffer, registered_buffer.last_position))
 409    }
 410}
 411
/// The prediction currently offered for a project, plus how it came about and how it was
/// displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
    // Presumably end-to-end latency of the request — confirm.
    pub e2e_latency: std::time::Duration,
}
 420
 421impl CurrentEditPrediction {
 422    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 423        let Some(new_edits) = self
 424            .prediction
 425            .interpolate(&self.prediction.buffer.read(cx))
 426        else {
 427            return false;
 428        };
 429
 430        if self.prediction.buffer != old_prediction.prediction.buffer {
 431            return true;
 432        }
 433
 434        let Some(old_edits) = old_prediction
 435            .prediction
 436            .interpolate(&old_prediction.prediction.buffer.read(cx))
 437        else {
 438            return true;
 439        };
 440
 441        let requested_by_buffer_id = self.requested_by.buffer_id();
 442
 443        // This reduces the occurrence of UI thrash from replacing edits
 444        //
 445        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 446        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 447            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 448            && old_edits.len() == 1
 449            && new_edits.len() == 1
 450        {
 451            let (old_range, old_text) = &old_edits[0];
 452            let (new_range, new_text) = &new_edits[0];
 453            new_range == old_range && new_text.starts_with(old_text.as_ref())
 454        } else {
 455            true
 456        }
 457    }
 458}
 459
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// A diagnostics update, not tied to a specific buffer.
    DiagnosticsUpdate,
    /// Activity in the buffer with this entity id.
    Buffer(EntityId),
}
 465
 466impl PredictionRequestedBy {
 467    pub fn buffer_id(&self) -> Option<EntityId> {
 468        match self {
 469            PredictionRequestedBy::DiagnosticsUpdate => None,
 470            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 471        }
 472    }
 473}
 474
// NOTE(review): not referenced in this part of the file; presumably the number of rows
// around the cursor searched for diagnostics — confirm at the use site.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;
 476
/// Scope of a diagnostics search.
// Presumably `Local` means near the cursor and `Global` means anywhere — confirm at the use site.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
 482
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Id recorded in `ProjectState::cancelled_predictions` when the request is cancelled.
    id: usize,
    /// Resolves to the resulting prediction's id, if one was produced.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 491
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction applies to this buffer.
    Local { prediction: &'a EditPrediction },
    /// Presumably a prediction located elsewhere that the user can jump to — confirm.
    Jump { prediction: &'a EditPrediction },
}
 498
 499#[cfg(test)]
 500impl std::ops::Deref for BufferEditPrediction<'_> {
 501    type Target = EditPrediction;
 502
 503    fn deref(&self) -> &Self::Target {
 504        match self {
 505            BufferEditPrediction::Local { prediction } => prediction,
 506            BufferEditPrediction::Jump { prediction } => prediction,
 507        }
 508    }
 509}
 510
#[derive(Clone)]
/// A shown prediction waiting to be reported as "settled" for its buffer.
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Region the prediction could edit.
    editable_anchor_range: Range<Anchor>,
    example: Option<ExampleSpec>,
    /// When tracking started (presumably compared against `EDIT_PREDICTION_SETTLED_TTL`).
    enqueued_at: Instant,
    /// Last edit in the region (presumably compared against `EDIT_PREDICTION_SETTLED_QUIESCENCE`).
    last_edit_at: Instant,
    e2e_latency: std::time::Duration,
}
 521
/// A buffer tracked by a project's edit-prediction state.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    /// Most recently observed snapshot (presumably the baseline for the next edit event).
    snapshot: TextBufferSnapshot,
    /// Predictions in this buffer that are waiting to settle.
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last recorded position in the buffer; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
 529
/// An edit event still being accumulated for one buffer; `finalize` turns it into a
/// [`StoredEvent`].
#[derive(Clone)]
struct LastEvent {
    /// Snapshot from before the first edit in this event (the diff base).
    old_snapshot: TextBufferSnapshot,
    /// Snapshot after the most recent edit.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    latest_edit_range: Range<Anchor>,
    /// Region covered by all edits in this event; this is the region that gets diffed.
    total_edit_range: Range<Anchor>,
    /// Value of `total_edit_range` at the last editing pause; used for the pre-pause part.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    /// Whether the edits came from an applied prediction.
    predicted: bool,
    /// Snapshot at the last editing pause; `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 543
 544impl LastEvent {
 545    pub fn finalize(
 546        &self,
 547        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 548        cx: &App,
 549    ) -> Option<StoredEvent> {
 550        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 551        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 552
 553        let in_open_source_repo =
 554            [self.new_file.as_ref(), self.old_file.as_ref()]
 555                .iter()
 556                .all(|file| {
 557                    file.is_some_and(|file| {
 558                        license_detection_watchers
 559                            .get(&file.worktree_id(cx))
 560                            .is_some_and(|watcher| watcher.is_project_open_source())
 561                    })
 562                });
 563
 564        let (diff, edit_range) = compute_diff_between_snapshots_in_range(
 565            &self.old_snapshot,
 566            &self.new_snapshot,
 567            &self.total_edit_range,
 568        )?;
 569
 570        if path == old_path && diff.is_empty() {
 571            None
 572        } else {
 573            Some(StoredEvent {
 574                event: Arc::new(zeta_prompt::Event::BufferChange {
 575                    old_path,
 576                    path,
 577                    diff,
 578                    in_open_source_repo,
 579                    predicted: self.predicted,
 580                }),
 581                old_snapshot: self.old_snapshot.clone(),
 582                new_snapshot_version: self.new_snapshot.version.clone(),
 583                total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
 584                    ..self.new_snapshot.anchor_before(edit_range.end),
 585            })
 586        }
 587    }
 588
 589    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 590        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 591            return (self.clone(), None);
 592        };
 593
 594        let total_edit_range_before_pause = self
 595            .total_edit_range_at_last_pause_boundary
 596            .clone()
 597            .unwrap_or_else(|| self.total_edit_range.clone());
 598
 599        let Some(total_edit_range_after_pause) =
 600            compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
 601        else {
 602            return (self.clone(), None);
 603        };
 604
 605        let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
 606        let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
 607
 608        let before = LastEvent {
 609            old_snapshot: self.old_snapshot.clone(),
 610            new_snapshot: boundary_snapshot.clone(),
 611            old_file: self.old_file.clone(),
 612            new_file: self.new_file.clone(),
 613            latest_edit_range: latest_edit_range_before_pause,
 614            total_edit_range: total_edit_range_before_pause,
 615            total_edit_range_at_last_pause_boundary: None,
 616            predicted: self.predicted,
 617            snapshot_after_last_editing_pause: None,
 618            last_edit_time: self.last_edit_time,
 619        };
 620
 621        let after = LastEvent {
 622            old_snapshot: boundary_snapshot.clone(),
 623            new_snapshot: self.new_snapshot.clone(),
 624            old_file: self.old_file.clone(),
 625            new_file: self.new_file.clone(),
 626            latest_edit_range: latest_edit_range_after_pause,
 627            total_edit_range: total_edit_range_after_pause,
 628            total_edit_range_at_last_pause_boundary: None,
 629            predicted: self.predicted,
 630            snapshot_after_last_editing_pause: None,
 631            last_edit_time: self.last_edit_time,
 632        };
 633
 634        (before, Some(after))
 635    }
 636}
 637
 638fn compute_total_edit_range_between_snapshots(
 639    old_snapshot: &TextBufferSnapshot,
 640    new_snapshot: &TextBufferSnapshot,
 641) -> Option<Range<Anchor>> {
 642    let edits: Vec<Edit<usize>> = new_snapshot
 643        .edits_since::<usize>(&old_snapshot.version)
 644        .collect();
 645
 646    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 647    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 648    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 649
 650    Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
 651}
 652
 653fn compute_old_range_for_new_range(
 654    old_snapshot: &TextBufferSnapshot,
 655    new_snapshot: &TextBufferSnapshot,
 656    total_edit_range: &Range<Anchor>,
 657) -> Option<Range<Point>> {
 658    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
 659    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);
 660
 661    let edits: Vec<Edit<usize>> = new_snapshot
 662        .edits_since::<usize>(&old_snapshot.version)
 663        .collect();
 664    let mut old_start_offset = None;
 665    let mut old_end_offset = None;
 666    let mut delta: isize = 0;
 667
 668    for edit in &edits {
 669        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
 670            old_start_offset = Some(if new_start_offset < edit.new.start {
 671                new_start_offset.checked_add_signed(-delta)?
 672            } else {
 673                edit.old.start
 674            });
 675        }
 676
 677        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
 678            old_end_offset = Some(if new_end_offset < edit.new.start {
 679                new_end_offset.checked_add_signed(-delta)?
 680            } else {
 681                edit.old.end
 682            });
 683        }
 684
 685        delta += edit.new.len() as isize - edit.old.len() as isize;
 686    }
 687
 688    let old_start_offset =
 689        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
 690    let old_end_offset =
 691        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));
 692
 693    Some(
 694        old_snapshot.offset_to_point(old_start_offset)
 695            ..old_snapshot.offset_to_point(old_end_offset),
 696    )
 697}
 698
 699fn compute_diff_between_snapshots_in_range(
 700    old_snapshot: &TextBufferSnapshot,
 701    new_snapshot: &TextBufferSnapshot,
 702    total_edit_range: &Range<Anchor>,
 703) -> Option<(String, Range<Point>)> {
 704    let new_start_point = total_edit_range.start.to_point(new_snapshot);
 705    let new_end_point = total_edit_range.end.to_point(new_snapshot);
 706    let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
 707    let old_start_point = old_range.start;
 708    let old_end_point = old_range.end;
 709
 710    const CONTEXT_LINES: u32 = 3;
 711
 712    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 713    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 714    let old_context_end_row =
 715        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 716    let new_context_end_row =
 717        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 718
 719    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 720    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 721    let old_end_line_offset = old_snapshot
 722        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 723    let new_end_line_offset = new_snapshot
 724        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 725    let old_edit_range = old_start_line_offset..old_end_line_offset;
 726    let new_edit_range = new_start_line_offset..new_end_line_offset;
 727
 728    if new_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
 729        || old_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
 730    {
 731        return None;
 732    }
 733
 734    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 735    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 736
 737    let diff = language::unified_diff_with_offsets(
 738        &old_region_text,
 739        &new_region_text,
 740        old_context_start_row,
 741        new_context_start_row,
 742    );
 743
 744    Some((diff, new_start_point..new_end_point))
 745}
 746
 747fn buffer_path_with_id_fallback(
 748    file: Option<&Arc<dyn File>>,
 749    snapshot: &TextBufferSnapshot,
 750    cx: &App,
 751) -> Arc<Path> {
 752    if let Some(file) = file {
 753        file.full_path(cx).into()
 754    } else {
 755        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 756    }
 757}
 758
 759impl EditPredictionStore {
 760    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 761        cx.try_global::<EditPredictionStoreGlobal>()
 762            .map(|global| global.0.clone())
 763    }
 764
 765    pub fn global(
 766        client: &Arc<Client>,
 767        user_store: &Entity<UserStore>,
 768        cx: &mut App,
 769    ) -> Entity<Self> {
 770        cx.try_global::<EditPredictionStoreGlobal>()
 771            .map(|global| global.0.clone())
 772            .unwrap_or_else(|| {
 773                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 774                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 775                ep_store
 776            })
 777    }
 778
 779    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 780        let data_collection_choice = Self::load_data_collection_choice(cx);
 781
 782        let llm_token = LlmApiToken::global(cx);
 783
 784        let (reject_tx, reject_rx) = mpsc::unbounded();
 785        cx.background_spawn({
 786            let client = client.clone();
 787            let llm_token = llm_token.clone();
 788            let app_version = AppVersion::global(cx);
 789            let background_executor = cx.background_executor().clone();
 790            async move {
 791                Self::handle_rejected_predictions(
 792                    reject_rx,
 793                    client,
 794                    llm_token,
 795                    app_version,
 796                    background_executor,
 797                )
 798                .await
 799            }
 800        })
 801        .detach();
 802
 803        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
 804        cx.spawn(async move |this, cx| {
 805            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
 806        })
 807        .detach();
 808
 809        let mut current_user = user_store.read(cx).watch_current_user();
 810        let fetch_experiments_task = cx.spawn(async move |this, cx| {
 811            while current_user.borrow().is_none() {
 812                current_user.next().await;
 813            }
 814            this.update(cx, |this, cx| {
 815                this.refresh_available_experiments(cx);
 816            })
 817            .log_err();
 818        });
 819
 820        let this = Self {
 821            projects: HashMap::default(),
 822            client,
 823            user_store,
 824            llm_token,
 825            _fetch_experiments_task: fetch_experiments_task,
 826            update_required: false,
 827            edit_prediction_model: EditPredictionModel::Zeta,
 828            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 829            preferred_experiment: None,
 830            available_experiments: Vec::new(),
 831            sweep_ai: SweepAi::new(cx),
 832            mercury: Mercury::new(cx),
 833
 834            data_collection_choice,
 835            reject_predictions_tx: reject_tx,
 836            settled_predictions_tx,
 837            rated_predictions: Default::default(),
 838            shown_predictions: Default::default(),
 839            #[cfg(test)]
 840            settled_event_callback: None,
 841        };
 842
 843        this
 844    }
 845
 846    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 847        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 848        let format = ZetaFormat::parse(&version_str).ok()?;
 849        let model_id = env::var("ZED_ZETA_MODEL").ok();
 850        let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
 851        Some(Zeta2RawConfig {
 852            model_id,
 853            environment,
 854            format,
 855        })
 856    }
 857
 858    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 859        self.edit_prediction_model = model;
 860    }
 861
 862    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
 863        self.zeta2_raw_config = Some(config);
 864    }
 865
 866    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
 867        self.zeta2_raw_config.as_ref()
 868    }
 869
 870    pub fn preferred_experiment(&self) -> Option<&str> {
 871        self.preferred_experiment.as_deref()
 872    }
 873
 874    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
 875        self.preferred_experiment = experiment;
 876    }
 877
 878    pub fn available_experiments(&self) -> &[String] {
 879        &self.available_experiments
 880    }
 881
 882    pub fn active_experiment(&self) -> Option<&str> {
 883        self.preferred_experiment.as_deref().or_else(|| {
 884            self.shown_predictions
 885                .iter()
 886                .find_map(|p| p.model_version.as_ref())
 887                .and_then(|model_version| model_version.strip_prefix("zeta2:"))
 888        })
 889    }
 890
 891    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
 892        let client = self.client.clone();
 893        let llm_token = self.llm_token.clone();
 894        let app_version = AppVersion::global(cx);
 895        let organization_id = self
 896            .user_store
 897            .read(cx)
 898            .current_organization()
 899            .map(|organization| organization.id.clone());
 900
 901        cx.spawn(async move |this, cx| {
 902            let experiments = cx
 903                .background_spawn(async move {
 904                    let http_client = client.http_client();
 905                    let token = llm_token.acquire(&client, organization_id).await?;
 906                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
 907                    let request = http_client::Request::builder()
 908                        .method(Method::GET)
 909                        .uri(url.as_ref())
 910                        .header("Authorization", format!("Bearer {}", token))
 911                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 912                        .body(Default::default())?;
 913                    let mut response = http_client.send(request).await?;
 914                    if response.status().is_success() {
 915                        let mut body = Vec::new();
 916                        response.body_mut().read_to_end(&mut body).await?;
 917                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
 918                        Ok(experiments)
 919                    } else {
 920                        let mut body = String::new();
 921                        response.body_mut().read_to_string(&mut body).await?;
 922                        anyhow::bail!(
 923                            "Failed to fetch experiments: {:?}\nBody: {}",
 924                            response.status(),
 925                            body
 926                        );
 927                    }
 928                })
 929                .await?;
 930            this.update(cx, |this, cx| {
 931                this.available_experiments = experiments;
 932                cx.notify();
 933            })?;
 934            anyhow::Ok(())
 935        })
 936        .detach_and_log_err(cx);
 937    }
 938
 939    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 940        use ui::IconName;
 941        match self.edit_prediction_model {
 942            EditPredictionModel::Sweep => {
 943                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 944                    .with_disabled(IconName::SweepAiDisabled)
 945                    .with_up(IconName::SweepAiUp)
 946                    .with_down(IconName::SweepAiDown)
 947                    .with_error(IconName::SweepAiError)
 948            }
 949            EditPredictionModel::Mercury => {
 950                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 951            }
 952            EditPredictionModel::Zeta => {
 953                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 954                    .with_disabled(IconName::ZedPredictDisabled)
 955                    .with_up(IconName::ZedPredictUp)
 956                    .with_down(IconName::ZedPredictDown)
 957                    .with_error(IconName::ZedPredictError)
 958            }
 959            EditPredictionModel::Fim { .. } => {
 960                let settings = &all_language_settings(None, cx).edit_predictions;
 961                match settings.provider {
 962                    EditPredictionProvider::Ollama => {
 963                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 964                    }
 965                    _ => {
 966                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 967                    }
 968                }
 969            }
 970        }
 971    }
 972
 973    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 974        self.sweep_ai.api_token.read(cx).has_key()
 975    }
 976
 977    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 978        self.mercury.api_token.read(cx).has_key()
 979    }
 980
 981    pub fn mercury_has_payment_required_error(&self) -> bool {
 982        self.mercury.has_payment_required_error()
 983    }
 984
 985    pub fn clear_history(&mut self) {
 986        for project_state in self.projects.values_mut() {
 987            project_state.events.clear();
 988            project_state.last_event.take();
 989        }
 990    }
 991
 992    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 993        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 994            project_state.events.clear();
 995            project_state.last_event.take();
 996        }
 997    }
 998
 999    pub fn edit_history_for_project(
1000        &self,
1001        project: &Entity<Project>,
1002        cx: &App,
1003    ) -> Vec<StoredEvent> {
1004        self.projects
1005            .get(&project.entity_id())
1006            .map(|project_state| project_state.events(cx))
1007            .unwrap_or_default()
1008    }
1009
1010    pub fn context_for_project<'a>(
1011        &'a self,
1012        project: &Entity<Project>,
1013        cx: &'a mut App,
1014    ) -> Vec<RelatedFile> {
1015        self.projects
1016            .get(&project.entity_id())
1017            .map(|project_state| {
1018                project_state.context.update(cx, |context, cx| {
1019                    context
1020                        .related_files_with_buffers(cx)
1021                        .map(|(mut related_file, buffer)| {
1022                            related_file.in_open_source_repo = buffer
1023                                .read(cx)
1024                                .file()
1025                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
1026                            related_file
1027                        })
1028                        .collect()
1029                })
1030            })
1031            .unwrap_or_default()
1032    }
1033
1034    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1035        self.projects
1036            .get(&project.entity_id())
1037            .and_then(|project| project.copilot.clone())
1038    }
1039
1040    pub fn start_copilot_for_project(
1041        &mut self,
1042        project: &Entity<Project>,
1043        cx: &mut Context<Self>,
1044    ) -> Option<Entity<Copilot>> {
1045        if DisableAiSettings::get(None, cx).disable_ai {
1046            return None;
1047        }
1048        let state = self.get_or_init_project(project, cx);
1049
1050        if state.copilot.is_some() {
1051            return state.copilot.clone();
1052        }
1053        let _project = project.clone();
1054        let project = project.read(cx);
1055
1056        let node = project.node_runtime().cloned();
1057        if let Some(node) = node {
1058            let next_id = project.languages().next_language_server_id();
1059            let fs = project.fs().clone();
1060
1061            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1062            state.copilot = Some(copilot.clone());
1063            Some(copilot)
1064        } else {
1065            None
1066        }
1067    }
1068
1069    pub fn context_for_project_with_buffers<'a>(
1070        &'a self,
1071        project: &Entity<Project>,
1072        cx: &'a mut App,
1073    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1074        self.projects
1075            .get(&project.entity_id())
1076            .map(|project| {
1077                project.context.update(cx, |context, cx| {
1078                    context.related_files_with_buffers(cx).collect()
1079                })
1080            })
1081            .unwrap_or_default()
1082    }
1083
1084    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1085        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1086            self.user_store.read(cx).edit_prediction_usage()
1087        } else {
1088            None
1089        }
1090    }
1091
1092    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1093        self.get_or_init_project(project, cx);
1094    }
1095
1096    pub fn register_buffer(
1097        &mut self,
1098        buffer: &Entity<Buffer>,
1099        project: &Entity<Project>,
1100        cx: &mut Context<Self>,
1101    ) {
1102        let project_state = self.get_or_init_project(project, cx);
1103        Self::register_buffer_impl(project_state, buffer, project, cx);
1104    }
1105
    /// Returns the tracking state for `project`, creating it on first access.
    ///
    /// When the state is created, this:
    /// - builds a `RelatedExcerptStore` for the project and forwards its events to
    ///   `handle_excerpt_store_event` (that subscription is detached);
    /// - subscribes to the project's events via `handle_project_event`;
    /// - removes this state again (and notifies) when the project entity is released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Forward context-retrieval events so they can reach the debug channel.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                // These subscriptions are owned by the state, so they go away with it.
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop all per-project state once the project entity itself is released.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1146
1147    pub fn remove_project(&mut self, project: &Entity<Project>) {
1148        self.projects.remove(&project.entity_id());
1149    }
1150
1151    fn handle_excerpt_store_event(
1152        &mut self,
1153        project_entity_id: EntityId,
1154        event: &RelatedExcerptStoreEvent,
1155    ) {
1156        if let Some(project_state) = self.projects.get(&project_entity_id) {
1157            if let Some(debug_tx) = project_state.debug_tx.clone() {
1158                match event {
1159                    RelatedExcerptStoreEvent::StartedRefresh => {
1160                        debug_tx
1161                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
1162                                ContextRetrievalStartedDebugEvent {
1163                                    project_entity_id: project_entity_id,
1164                                    timestamp: Instant::now(),
1165                                    search_prompt: String::new(),
1166                                },
1167                            ))
1168                            .ok();
1169                    }
1170                    RelatedExcerptStoreEvent::FinishedRefresh {
1171                        cache_hit_count,
1172                        cache_miss_count,
1173                        mean_definition_latency,
1174                        max_definition_latency,
1175                    } => {
1176                        debug_tx
1177                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
1178                                ContextRetrievalFinishedDebugEvent {
1179                                    project_entity_id: project_entity_id,
1180                                    timestamp: Instant::now(),
1181                                    metadata: vec![
1182                                        (
1183                                            "Cache Hits",
1184                                            format!(
1185                                                "{}/{}",
1186                                                cache_hit_count,
1187                                                cache_hit_count + cache_miss_count
1188                                            )
1189                                            .into(),
1190                                        ),
1191                                        (
1192                                            "Max LSP Time",
1193                                            format!("{} ms", max_definition_latency.as_millis())
1194                                                .into(),
1195                                        ),
1196                                        (
1197                                            "Mean LSP Time",
1198                                            format!("{} ms", mean_definition_latency.as_millis())
1199                                                .into(),
1200                                        ),
1201                                    ],
1202                                },
1203                            ))
1204                            .ok();
1205                    }
1206                }
1207            }
1208        }
1209    }
1210
1211    pub fn debug_info(
1212        &mut self,
1213        project: &Entity<Project>,
1214        cx: &mut Context<Self>,
1215    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1216        let project_state = self.get_or_init_project(project, cx);
1217        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1218        project_state.debug_tx = Some(debug_watch_tx);
1219        debug_watch_rx
1220    }
1221
1222    fn handle_project_event(
1223        &mut self,
1224        project: Entity<Project>,
1225        event: &project::Event,
1226        cx: &mut Context<Self>,
1227    ) {
1228        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1229            return;
1230        }
1231        // TODO [zeta2] init with recent paths
1232        match event {
1233            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1234                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1235                    return;
1236                };
1237                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1238                if let Some(path) = path {
1239                    if let Some(ix) = project_state
1240                        .recent_paths
1241                        .iter()
1242                        .position(|probe| probe == &path)
1243                    {
1244                        project_state.recent_paths.remove(ix);
1245                    }
1246                    project_state.recent_paths.push_front(path);
1247                }
1248            }
1249            project::Event::DiagnosticsUpdated { .. } => {
1250                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1251                    self.refresh_prediction_from_diagnostics(
1252                        project,
1253                        DiagnosticSearchScope::Global,
1254                        cx,
1255                    );
1256                }
1257            }
1258            _ => (),
1259        }
1260    }
1261
    /// Returns the `RegisteredBuffer` entry for `buffer`, creating it on first registration.
    ///
    /// Every call also makes sure a `LicenseDetectionWatcher` exists for the buffer's
    /// worktree. There is one per worktree, and it is removed when that worktree is released.
    ///
    /// A new entry stores the buffer's current text snapshot and file, and installs two
    /// subscriptions:
    /// - `Edited` events are forwarded to `report_changes_for_buffer` with
    ///   `is_predicted = false` and the event's `is_local` flag;
    /// - when the buffer is released, its entry is removed from `registered_buffers`.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Buffers without a file (or whose worktree is gone) get no license watcher.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher when its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // The project is held weakly; edits that arrive after it can no
                            // longer be upgraded are ignored.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        // Forget the buffer once it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1330
    /// Records the edits made to `buffer` since it was last observed.
    ///
    /// This is called from the buffer `Edited` subscription and when a prediction is
    /// accepted (`is_predicted = true`). The steps are:
    /// 1. Diff the buffer's new snapshot against the stored one and compute a combined anchor
    ///    range spanning all edits.
    /// 2. Bump `last_edit_at` on pending predictions whose editable range overlaps that range.
    /// 3. For local edits, record a `UserActionRecord` classifying the action.
    /// 4. If the edit belongs in history (local, or a collaborator edit overlapping the locality
    ///    region), either coalesce it into `last_event` or finalize `last_event` into `events`
    ///    and start a new one.
    ///
    /// It returns early when the buffer version is unchanged or no edits are found.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        // Store the new snapshot/file and keep the previous ones for diffing below.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        // Accumulate edit statistics and a single range from the first edit's start to the
        // last edit's end.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Mark pending predictions whose editable region was just touched.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if is_local {
            // Heuristic classification: when the inserted (or deleted) length equals the
            // number of edits, treat it as one character per edit; otherwise a selection.
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                // A clock before the Unix epoch falls back to 0 rather than failing.
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        if !include_in_history {
            return;
        }

        let is_recordable_history_edit =
            compute_diff_between_snapshots_in_range(&old_snapshot, &new_snapshot, &edit_range)
                .is_some();

        let events = &mut project_state.events;

        // An edit that yields no diff closes out the in-progress event without starting a new one.
        if !is_recordable_history_edit {
            if let Some(event) = project_state.last_event.take() {
                if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                    // Trim the oldest event; this keeps `events` below EVENT_COUNT_MAX
                    // (presumably leaving room for `last_event` — TODO confirm).
                    if events.len() + 1 >= EVENT_COUNT_MAX {
                        events.pop_front();
                    }
                    events.push_back(event);
                }
            }
            return;
        }

        if let Some(last_event) = project_state.last_event.as_mut() {
            // Coalesce only a directly following snapshot of the same buffer, from the same
            // source (predicted vs. typed), within CHANGE_GROUPING_LINE_SPAN lines.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // After a pause of at least LAST_CHANGE_GROUPING_TIME, remember the state at
                // the pause boundary before extending the event.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescable: finalize the previous event (if any) and start a new one.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1497
1498    fn prediction_at(
1499        &mut self,
1500        buffer: &Entity<Buffer>,
1501        position: Option<language::Anchor>,
1502        project: &Entity<Project>,
1503        cx: &App,
1504    ) -> Option<BufferEditPrediction<'_>> {
1505        let project_state = self.projects.get_mut(&project.entity_id())?;
1506        if let Some(position) = position
1507            && let Some(buffer) = project_state
1508                .registered_buffers
1509                .get_mut(&buffer.entity_id())
1510        {
1511            buffer.last_position = Some(position);
1512        }
1513
1514        let CurrentEditPrediction {
1515            requested_by,
1516            prediction,
1517            ..
1518        } = project_state.current_prediction.as_ref()?;
1519
1520        if prediction.targets_buffer(buffer.read(cx)) {
1521            Some(BufferEditPrediction::Local { prediction })
1522        } else {
1523            let show_jump = match requested_by {
1524                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1525                    requested_by_buffer_id == &buffer.entity_id()
1526                }
1527                PredictionRequestedBy::DiagnosticsUpdate => true,
1528            };
1529
1530            if show_jump {
1531                Some(BufferEditPrediction::Jump { prediction })
1532            } else {
1533                None
1534            }
1535        }
1536    }
1537
    /// Accepts the project's current prediction.
    ///
    /// The steps, in order:
    /// 1. Take `current_prediction`, clearing it.
    /// 2. Record the resulting buffer change as a predicted, local edit.
    /// 3. Cancel all pending predictions.
    /// 4. Report the acceptance to the active backend. Zeta reports only for cloud providers
    ///    (not Ollama / OpenAI-compatible), and `Fim` reports nothing.
    ///
    /// Does nothing if the project is untracked or has no current prediction.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(current_prediction) = self
            .projects
            .get_mut(&project.entity_id())
            .and_then(|project_state| project_state.current_prediction.take())
        else {
            return;
        };

        self.report_changes_for_buffer(
            &current_prediction.prediction.buffer,
            project,
            true,
            true,
            cx,
        );

        // Can't hold a &mut project_state reference across the report_changes_for_buffer call,
        // so look the state up again.
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        match self.edit_prediction_model {
            EditPredictionModel::Sweep => {
                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
            }
            EditPredictionModel::Mercury => {
                mercury::edit_prediction_accepted(
                    current_prediction.prediction.id,
                    self.client.http_client(),
                    cx,
                );
            }
            EditPredictionModel::Zeta => {
                // Local/self-hosted providers have no cloud endpoint to notify.
                let is_cloud = !matches!(
                    all_language_settings(None, cx).edit_predictions.provider,
                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
                );
                if is_cloud {
                    zeta::edit_prediction_accepted(self, current_prediction, cx)
                }
            }
            EditPredictionModel::Fim { .. } => {}
        }
    }
1587
    /// Background loop that batches rejected-prediction reports and sends them to the
    /// `/predict_edits/reject` LLM endpoint.
    ///
    /// Each received payload's `rejection` is appended to a local batch. While the batch holds
    /// fewer than half of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` entries, the loop keeps
    /// collecting whenever another payload shows up before `REJECT_REQUEST_DEBOUNCE` elapses;
    /// otherwise (timer fired, or the channel closed) it flushes. Each request carries at most
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the most recently batched rejections, and
    /// those are removed from the batch only when the request succeeds, so failed sends are
    /// retried the next time a rejection arrives.
    ///
    /// The loop ends once `rx` yields `None` (all senders dropped).
    ///
    /// NOTE(review): the whole batch is sent with the `organization_id` of the most recently
    /// received payload, and while requests keep failing `batched` grows without bound —
    /// confirm both are acceptable.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            // Small batch: if another rejection is already queued (or arrives before the
            // debounce timer), go back and collect it instead of sending right away. The peek is
            // polled first (`select_biased!`); a closed channel also falls through to the flush.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            // NOTE(review): panics if the LLM URL cannot be built.
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Send the newest `flush_count` rejections (the tail of the batch).
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Only forget what was actually delivered; on failure (logged) the items stay
            // queued for a later attempt.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1651
1652    async fn run_settled_predictions_worker(
1653        this: WeakEntity<Self>,
1654        mut rx: UnboundedReceiver<Instant>,
1655        cx: &mut AsyncApp,
1656    ) {
1657        let mut next_wake_time: Option<Instant> = None;
1658        loop {
1659            let now = cx.background_executor().now();
1660            if let Some(wake_time) = next_wake_time.take() {
1661                cx.background_executor()
1662                    .timer(wake_time.duration_since(now))
1663                    .await;
1664            } else {
1665                let Some(new_enqueue_time) = rx.next().await else {
1666                    break;
1667                };
1668                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1669                while rx.next().now_or_never().flatten().is_some() {}
1670                continue;
1671            }
1672
1673            let Some(this) = this.upgrade() else {
1674                break;
1675            };
1676
1677            let now = cx.background_executor().now();
1678
1679            let mut oldest_edited_at = None;
1680
1681            this.update(cx, |this, _| {
1682                for (_, project_state) in this.projects.iter_mut() {
1683                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
1684                        registered_buffer
1685                            .pending_predictions
1686                            .retain_mut(|pending_prediction| {
1687                                let age =
1688                                    now.saturating_duration_since(pending_prediction.enqueued_at);
1689                                if age >= EDIT_PREDICTION_SETTLED_TTL {
1690                                    return false;
1691                                }
1692
1693                                let quiet_for =
1694                                    now.saturating_duration_since(pending_prediction.last_edit_at);
1695                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
1696                                    let settled_editable_region = registered_buffer
1697                                        .snapshot
1698                                        .text_for_range(
1699                                            pending_prediction.editable_anchor_range.clone(),
1700                                        )
1701                                        .collect::<String>();
1702
1703                                    #[cfg(test)]
1704                                    if let Some(callback) = &this.settled_event_callback {
1705                                        callback(
1706                                            pending_prediction.request_id.clone(),
1707                                            settled_editable_region.clone(),
1708                                        );
1709                                    }
1710
1711                                    telemetry::event!(
1712                                        EDIT_PREDICTION_SETTLED_EVENT,
1713                                        request_id = pending_prediction.request_id.0.clone(),
1714                                        settled_editable_region,
1715                                        example = pending_prediction.example.take(),
1716                                        e2e_latency = pending_prediction.e2e_latency.as_millis(),
1717                                    );
1718
1719                                    return false;
1720                                }
1721
1722                                if oldest_edited_at
1723                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
1724                                {
1725                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
1726                                }
1727
1728                                true
1729                            });
1730                    }
1731                }
1732            });
1733
1734            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1735        }
1736    }
1737
1738    pub(crate) fn enqueue_settled_prediction(
1739        &mut self,
1740        request_id: EditPredictionId,
1741        project: &Entity<Project>,
1742        edited_buffer: &Entity<Buffer>,
1743        edited_buffer_snapshot: &BufferSnapshot,
1744        editable_offset_range: Range<usize>,
1745        example: Option<ExampleSpec>,
1746        e2e_latency: std::time::Duration,
1747        cx: &mut Context<Self>,
1748    ) {
1749        let this = &mut *self;
1750        let project_state = this.get_or_init_project(project, cx);
1751        if let Some(buffer) = project_state
1752            .registered_buffers
1753            .get_mut(&edited_buffer.entity_id())
1754        {
1755            let now = cx.background_executor().now();
1756            buffer.pending_predictions.push(PendingSettledPrediction {
1757                request_id: request_id,
1758                editable_anchor_range: edited_buffer_snapshot
1759                    .anchor_range_around(editable_offset_range),
1760                example,
1761                e2e_latency,
1762                enqueued_at: now,
1763                last_edit_at: now,
1764            });
1765            this.settled_predictions_tx.unbounded_send(now).ok();
1766        }
1767    }
1768
1769    fn reject_current_prediction(
1770        &mut self,
1771        reason: EditPredictionRejectReason,
1772        project: &Entity<Project>,
1773        cx: &App,
1774    ) {
1775        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1776            project_state.pending_predictions.clear();
1777            if let Some(prediction) = project_state.current_prediction.take() {
1778                let model_version = prediction.prediction.model_version.clone();
1779                self.reject_prediction(
1780                    prediction.prediction.id,
1781                    reason,
1782                    prediction.was_shown,
1783                    model_version,
1784                    Some(prediction.e2e_latency),
1785                    cx,
1786                );
1787            }
1788        };
1789    }
1790
1791    fn did_show_current_prediction(
1792        &mut self,
1793        project: &Entity<Project>,
1794        display_type: edit_prediction_types::SuggestionDisplayType,
1795        cx: &mut Context<Self>,
1796    ) {
1797        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1798            return;
1799        };
1800
1801        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1802            return;
1803        };
1804
1805        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1806        let previous_shown_with = current_prediction.shown_with;
1807
1808        if previous_shown_with.is_none() || !is_jump {
1809            current_prediction.shown_with = Some(display_type);
1810        }
1811
1812        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1813
1814        if is_first_non_jump_show {
1815            current_prediction.was_shown = true;
1816        }
1817
1818        let display_type_changed = previous_shown_with != Some(display_type);
1819
1820        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1821            sweep_ai::edit_prediction_shown(
1822                &self.sweep_ai,
1823                self.client.clone(),
1824                &current_prediction.prediction,
1825                display_type,
1826                cx,
1827            );
1828        }
1829
1830        if is_first_non_jump_show {
1831            self.shown_predictions
1832                .push_front(current_prediction.prediction.clone());
1833            if self.shown_predictions.len() > 50 {
1834                let completion = self.shown_predictions.pop_back().unwrap();
1835                self.rated_predictions.remove(&completion.id);
1836            }
1837        }
1838    }
1839
1840    fn reject_prediction(
1841        &mut self,
1842        prediction_id: EditPredictionId,
1843        reason: EditPredictionRejectReason,
1844        was_shown: bool,
1845        model_version: Option<String>,
1846        e2e_latency: Option<std::time::Duration>,
1847        cx: &App,
1848    ) {
1849        match self.edit_prediction_model {
1850            EditPredictionModel::Zeta => {
1851                let is_cloud = !matches!(
1852                    all_language_settings(None, cx).edit_predictions.provider,
1853                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1854                );
1855
1856                if is_cloud {
1857                    let organization_id = self
1858                        .user_store
1859                        .read(cx)
1860                        .current_organization()
1861                        .map(|organization| organization.id.clone());
1862
1863                    self.reject_predictions_tx
1864                        .unbounded_send(EditPredictionRejectionPayload {
1865                            rejection: EditPredictionRejection {
1866                                request_id: prediction_id.to_string(),
1867                                reason,
1868                                was_shown,
1869                                model_version,
1870                                e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
1871                            },
1872                            organization_id,
1873                        })
1874                        .log_err();
1875                }
1876            }
1877            EditPredictionModel::Mercury => {
1878                mercury::edit_prediction_rejected(
1879                    prediction_id,
1880                    was_shown,
1881                    reason,
1882                    self.client.http_client(),
1883                    cx,
1884                );
1885            }
1886            EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1887        }
1888    }
1889
1890    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1891        self.projects
1892            .get(&project.entity_id())
1893            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1894    }
1895
1896    pub fn refresh_prediction_from_buffer(
1897        &mut self,
1898        project: Entity<Project>,
1899        buffer: Entity<Buffer>,
1900        position: language::Anchor,
1901        cx: &mut Context<Self>,
1902    ) {
1903        self.queue_prediction_refresh(
1904            project.clone(),
1905            PredictEditsRequestTrigger::Other,
1906            buffer.entity_id(),
1907            cx,
1908            move |this, cx| {
1909                let Some(request_task) = this
1910                    .update(cx, |this, cx| {
1911                        this.request_prediction(
1912                            &project,
1913                            &buffer,
1914                            position,
1915                            PredictEditsRequestTrigger::Other,
1916                            cx,
1917                        )
1918                    })
1919                    .log_err()
1920                else {
1921                    return Task::ready(anyhow::Ok(None));
1922                };
1923
1924                cx.spawn(async move |_cx| {
1925                    request_task.await.map(|prediction_result| {
1926                        prediction_result.map(|prediction_result| {
1927                            (
1928                                prediction_result,
1929                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1930                            )
1931                        })
1932                    })
1933                })
1934            },
1935        )
1936    }
1937
1938    pub fn refresh_prediction_from_diagnostics(
1939        &mut self,
1940        project: Entity<Project>,
1941        scope: DiagnosticSearchScope,
1942        cx: &mut Context<Self>,
1943    ) {
1944        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1945            return;
1946        }
1947
1948        if currently_following(&project, cx) {
1949            return;
1950        }
1951
1952        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1953            return;
1954        };
1955
1956        // Prefer predictions from buffer
1957        if project_state.current_prediction.is_some() {
1958            log::debug!(
1959                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
1960            );
1961            return;
1962        }
1963
1964        self.queue_prediction_refresh(
1965            project.clone(),
1966            PredictEditsRequestTrigger::Diagnostics,
1967            project.entity_id(),
1968            cx,
1969            move |this, cx| {
1970                let Some((active_buffer, snapshot, cursor_point)) = this
1971                    .read_with(cx, |this, cx| {
1972                        let project_state = this.projects.get(&project.entity_id())?;
1973                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
1974                        let snapshot = buffer.read(cx).snapshot();
1975
1976                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
1977                            return None;
1978                        }
1979
1980                        let cursor_point = position
1981                            .map(|pos| pos.to_point(&snapshot))
1982                            .unwrap_or_default();
1983
1984                        Some((buffer, snapshot, cursor_point))
1985                    })
1986                    .log_err()
1987                    .flatten()
1988                else {
1989                    return Task::ready(anyhow::Ok(None));
1990                };
1991
1992                cx.spawn(async move |cx| {
1993                    let diagnostic_search_range = match scope {
1994                        DiagnosticSearchScope::Local => {
1995                            let diagnostic_search_start =
1996                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1997                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1998                            Point::new(diagnostic_search_start, 0)
1999                                ..Point::new(diagnostic_search_end, 0)
2000                        }
2001                        DiagnosticSearchScope::Global => Default::default(),
2002                    };
2003
2004                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
2005                        active_buffer,
2006                        &snapshot,
2007                        diagnostic_search_range,
2008                        cursor_point,
2009                        &project,
2010                        cx,
2011                    )
2012                    .await?
2013                    else {
2014                        return anyhow::Ok(None);
2015                    };
2016
2017                    let Some(prediction_result) = this
2018                        .update(cx, |this, cx| {
2019                            this.request_prediction(
2020                                &project,
2021                                &jump_buffer,
2022                                jump_position,
2023                                PredictEditsRequestTrigger::Diagnostics,
2024                                cx,
2025                            )
2026                        })?
2027                        .await?
2028                    else {
2029                        return anyhow::Ok(None);
2030                    };
2031
2032                    this.update(cx, |this, cx| {
2033                        Some((
2034                            if this
2035                                .get_or_init_project(&project, cx)
2036                                .current_prediction
2037                                .is_none()
2038                            {
2039                                prediction_result
2040                            } else {
2041                                EditPredictionResult {
2042                                    id: prediction_result.id,
2043                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
2044                                    e2e_latency: prediction_result.e2e_latency,
2045                                }
2046                            },
2047                            PredictionRequestedBy::DiagnosticsUpdate,
2048                        ))
2049                    })
2050                })
2051            },
2052        );
2053    }
2054
2055    fn predictions_enabled_at(
2056        snapshot: &BufferSnapshot,
2057        position: Option<language::Anchor>,
2058        cx: &App,
2059    ) -> bool {
2060        let file = snapshot.file();
2061        let all_settings = all_language_settings(file, cx);
2062        if !all_settings.show_edit_predictions(snapshot.language(), cx)
2063            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
2064        {
2065            return false;
2066        }
2067
2068        if let Some(last_position) = position {
2069            let settings = snapshot.settings_at(last_position, cx);
2070
2071            if !settings.edit_predictions_disabled_in.is_empty()
2072                && let Some(scope) = snapshot.language_scope_at(last_position)
2073                && let Some(scope_name) = scope.override_name()
2074                && settings
2075                    .edit_predictions_disabled_in
2076                    .iter()
2077                    .any(|s| s == scope_name)
2078            {
2079                return false;
2080            }
2081        }
2082
2083        true
2084    }
2085
    /// Minimum spacing between two prediction refreshes that share a throttle entity and
    /// trigger kind; `queue_prediction_refresh` delays a new refresh until it has elapsed.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2087}
2088
2089fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2090    let Some(app_state) = AppState::try_global(cx).and_then(|app_state| app_state.upgrade()) else {
2091        return false;
2092    };
2093
2094    app_state
2095        .workspace_store
2096        .read(cx)
2097        .workspaces()
2098        .filter_map(|workspace| workspace.upgrade())
2099        .any(|workspace| {
2100            workspace.read(cx).project().entity_id() == project.entity_id()
2101                && workspace
2102                    .read(cx)
2103                    .leader_for_pane(workspace.read(cx).active_pane())
2104                    .is_some()
2105        })
2106}
2107
2108fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2109    match provider {
2110        EditPredictionProvider::Zed
2111        | EditPredictionProvider::Sweep
2112        | EditPredictionProvider::Mercury
2113        | EditPredictionProvider::Ollama
2114        | EditPredictionProvider::OpenAiCompatibleApi
2115        | EditPredictionProvider::Experimental(_) => true,
2116        EditPredictionProvider::None
2117        | EditPredictionProvider::Copilot
2118        | EditPredictionProvider::Codestral => false,
2119    }
2120}
2121
2122impl EditPredictionStore {
    /// Enqueues a prediction refresh for `project`, throttled per trigger kind.
    ///
    /// The spawned task goes through these steps:
    /// 1. It waits out any remainder of `THROTTLE_TIMEOUT` since the last refresh recorded for
    ///    the same `throttle_entity`. Diagnostics-triggered refreshes use a separate throttle
    ///    slot from all other triggers.
    /// 2. It skips the request if it was cancelled while waiting, or if another refresh updated
    ///    the throttle slot after this one was enqueued.
    /// 3. Otherwise it records a new throttle timestamp and runs `do_refresh`.
    /// 4. Unless cancelled in the meantime, it either installs the result as the current
    ///    prediction or reports it as rejected.
    /// 5. It removes its own entry from the pending list and cancels every pending prediction
    ///    enqueued before it.
    ///
    /// The number of pending refreshes is capped per provider. At the cap, the most recently
    /// enqueued pending prediction is replaced by this one and cancelled. Providers not handled
    /// by the store (`None`, `Copilot`, `Codestral`) log an error and return without queuing.
    ///
    /// NOTE(review): neither early-return path (cancelled / another refresh already ran)
    /// removes this task's entry from `pending_predictions`. Confirm it is cleaned up elsewhere;
    /// otherwise `is_refreshing` may keep reporting `true`.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        request_trigger: PredictEditsRequestTrigger,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        // Diagnostics-triggered refreshes are throttled independently of all other triggers.
        fn select_throttle(
            project_state: &mut ProjectState,
            request_trigger: PredictEditsRequestTrigger,
        ) -> &mut Option<(EntityId, Instant)> {
            match request_trigger {
                PredictEditsRequestTrigger::Diagnostics => {
                    &mut project_state.last_jump_prediction_refresh
                }
                _ => &mut project_state.last_edit_prediction_refresh,
            }
        }

        // Per-provider policy: whether acceptance tracking is needed, and how many refresh
        // tasks may be pending at once.
        let (needs_acceptance_tracking, max_pending_predictions) =
            match all_language_settings(None, cx).edit_predictions.provider {
                EditPredictionProvider::Zed
                | EditPredictionProvider::Sweep
                | EditPredictionProvider::Mercury
                | EditPredictionProvider::Experimental(_) => (true, 2),
                EditPredictionProvider::Ollama => (false, 1),
                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
                EditPredictionProvider::None
                | EditPredictionProvider::Copilot
                | EditPredictionProvider::Codestral => {
                    log::error!("queue_prediction_refresh called with non-store provider");
                    return;
                }
            };

        // NOTE(review): `drop_on_cancel` is interpreted by `cancel_pending_prediction`, which is
        // not visible here. Presumably, tasks without acceptance tracking may simply be dropped
        // on cancellation; confirm.
        let drop_on_cancel = !needs_acceptance_tracking;
        let throttle_timeout = Self::THROTTLE_TIMEOUT;
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        // Snapshot of the throttle slot, compared later to detect that another refresh ran first.
        let throttle_at_enqueue = *select_throttle(project_state, request_trigger);

        let task = cx.spawn(async move |this, cx| {
            // Wait only if the last refresh in this slot was for the same throttle entity and the
            // throttle window has not yet elapsed.
            let throttle_wait = this
                .update(cx, |this, cx| {
                    let project_state = this.get_or_init_project(&project, cx);
                    let throttle = *select_throttle(project_state, request_trigger);

                    let now = cx.background_executor().now();
                    throttle.and_then(|(last_entity, last_timestamp)| {
                        if throttle_entity != last_entity {
                            return None;
                        }
                        (last_timestamp + throttle_timeout).checked_duration_since(now)
                    })
                })
                .ok()
                .flatten();

            if let Some(timeout) = throttle_wait {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request. Also skip if another task already
            // proceeded since we were enqueued (duplicate).
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                let was_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);
                if was_cancelled {
                    return;
                }

                // Another request has been already sent since this was enqueued
                if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
                    return;
                }

                // Claim the throttle slot for this request.
                let new_refresh = (throttle_entity, cx.background_executor().now());
                *select_throttle(project_state, request_trigger) = Some(new_refresh);
                is_cancelled = false;
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                // A cancellation that arrived during `do_refresh` discards the result.
                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                                shown_with: None,
                                e2e_latency: prediction_result.e2e_latency,
                            };

                            // Either the new prediction replaces the current one (which is then
                            // rejected as `Replaced`), or the new one is rejected as
                            // `CurrentPreferred`.
                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                        cx,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                        new_prediction.prediction.model_version,
                                        Some(new_prediction.e2e_latency),
                                        cx,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(
                                prediction_result.id,
                                reject_reason,
                                false,
                                None,
                                Some(prediction_result.e2e_latency),
                                cx,
                            );
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-acquire project state: the rejection calls above borrowed `this` mutably.
                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Pending entries are kept in enqueue order, so everything before this task's
                // index was enqueued earlier and is now obsolete.
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Keep at most `max_pending_predictions` tasks: at the cap, the most recently enqueued
        // pending prediction is popped, replaced by this one, and cancelled.
        // NOTE(review): the `push(..).unwrap()` calls assume the fixed-capacity vector can hold
        // `max_pending_predictions` (2) entries; confirm against the field's declared capacity.
        if project_state.pending_predictions.len() < max_pending_predictions {
            project_state
                .pending_predictions
                .push(PendingPrediction {
                    id: pending_prediction_id,
                    task,
                    drop_on_cancel,
                })
                .unwrap();
        } else {
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state
                .pending_predictions
                .push(PendingPrediction {
                    id: pending_prediction_id,
                    task,
                    drop_on_cancel,
                })
                .unwrap();
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
2335
2336    pub fn request_prediction(
2337        &mut self,
2338        project: &Entity<Project>,
2339        active_buffer: &Entity<Buffer>,
2340        position: language::Anchor,
2341        trigger: PredictEditsRequestTrigger,
2342        cx: &mut Context<Self>,
2343    ) -> Task<Result<Option<EditPredictionResult>>> {
2344        self.request_prediction_internal(
2345            project.clone(),
2346            active_buffer.clone(),
2347            position,
2348            trigger,
2349            cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2350            cx,
2351        )
2352    }
2353
2354    fn request_prediction_internal(
2355        &mut self,
2356        project: Entity<Project>,
2357        active_buffer: Entity<Buffer>,
2358        position: language::Anchor,
2359        trigger: PredictEditsRequestTrigger,
2360        allow_jump: bool,
2361        cx: &mut Context<Self>,
2362    ) -> Task<Result<Option<EditPredictionResult>>> {
2363        self.get_or_init_project(&project, cx);
2364        let project_state = self.projects.get(&project.entity_id()).unwrap();
2365        let stored_events = project_state.events(cx);
2366        let has_events = !stored_events.is_empty();
2367        let events: Vec<Arc<zeta_prompt::Event>> =
2368            stored_events.iter().map(|e| e.event.clone()).collect();
2369        let debug_tx = project_state.debug_tx.clone();
2370
2371        let snapshot = active_buffer.read(cx).snapshot();
2372        let cursor_point = position.to_point(&snapshot);
2373        let current_offset = position.to_offset(&snapshot);
2374
2375        let mut user_actions: Vec<UserActionRecord> =
2376            project_state.user_actions.iter().cloned().collect();
2377
2378        if let Some(last_action) = user_actions.last() {
2379            if last_action.buffer_id == active_buffer.entity_id()
2380                && current_offset != last_action.offset
2381            {
2382                let timestamp_epoch_ms = SystemTime::now()
2383                    .duration_since(UNIX_EPOCH)
2384                    .map(|d| d.as_millis() as u64)
2385                    .unwrap_or(0);
2386                user_actions.push(UserActionRecord {
2387                    action_type: UserActionType::CursorMovement,
2388                    buffer_id: active_buffer.entity_id(),
2389                    line_number: cursor_point.row,
2390                    offset: current_offset,
2391                    timestamp_epoch_ms,
2392                });
2393            }
2394        }
2395        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2396        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2397        let diagnostic_search_range =
2398            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2399
2400        let related_files = self.context_for_project(&project, cx);
2401
2402        let is_open_source = snapshot
2403            .file()
2404            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2405            && events.iter().all(|event| event.in_open_source_repo())
2406            && related_files.iter().all(|file| file.in_open_source_repo);
2407
2408        let can_collect_data = !cfg!(test)
2409            && is_open_source
2410            && self.is_data_collection_enabled(cx)
2411            && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2412
2413        let recent_paths = project_state.recent_paths.clone();
2414
2415        let inputs = EditPredictionModelInput {
2416            project: project.clone(),
2417            buffer: active_buffer,
2418            snapshot,
2419            position,
2420            events,
2421            related_files,
2422            recent_paths,
2423            trigger,
2424            diagnostic_search_range: diagnostic_search_range,
2425            debug_tx,
2426            user_actions,
2427            can_collect_data,
2428            is_open_source,
2429        };
2430
2431        let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2432
2433        let task = match self.edit_prediction_model {
2434            EditPredictionModel::Zeta => {
2435                zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2436            }
2437            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2438            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2439            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2440        };
2441
2442        cx.spawn(async move |this, cx| {
2443            let prediction = task.await?;
2444
2445            // Only fall back to diagnostics-based prediction if we got a
2446            // the model had nothing to suggest for the buffer
2447            if prediction.is_none()
2448                && allow_jump
2449                && has_events
2450                && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2451            {
2452                this.update(cx, |this, cx| {
2453                    this.refresh_prediction_from_diagnostics(
2454                        project,
2455                        DiagnosticSearchScope::Local,
2456                        cx,
2457                    );
2458                })?;
2459                return anyhow::Ok(None);
2460            }
2461
2462            Ok(prediction)
2463        })
2464    }
2465
    /// Finds a diagnostic location to jump to, outside the region around the local cursor.
    ///
    /// First searches `active_buffer`. It picks the diagnostic group whose primary entry
    /// starts closest (by row) to `active_buffer_cursor_point`, skipping groups that:
    /// - overlap `active_buffer_diagnostic_search_range`, or
    /// - lie within `DIAGNOSTIC_LINES_RANGE` rows of a selection head reported by
    ///   `selections_in_range`, unless also within that many rows of the local cursor.
    ///
    /// If none qualifies, it falls back to other project files that have diagnostic
    /// summaries. Files sharing the most leading path components with the active buffer's
    /// path are tried first. Each candidate is opened, and buffers that have any selections
    /// reported by `selections_in_range` are skipped. The first remaining buffer with a
    /// diagnostic yields the start of its entry with the minimum severity value.
    ///
    /// Returns `Ok(None)` when no location is found.
    ///
    /// # Errors
    ///
    /// Returns an error if opening a candidate buffer fails; the remaining candidates are
    /// not tried in that case.
    pub(crate) async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Rows of the selection heads returned by `selections_in_range`.
        // NOTE(review): presumably these are other collaborators' cursors (the `false`
        // argument looks like "exclude local selections"); confirm against
        // `BufferSnapshot::selections_in_range`.
        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
            .flat_map(|(_, _, _, selections)| {
                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
            })
            .collect();

        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Skip diagnostics inside the search window around the cursor.
                // Presumably the regular prediction request already covers that window
                // (TODO confirm).
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    return None;
                }
                // Avoid jumping next to a collaborator's cursor, unless the local cursor
                // is also nearby.
                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
                });
                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
                    <= DIAGNOSTIC_LINES_RANGE;
                if near_collaborator && !near_local {
                    return None;
                }
                Some(range.start)
            })
            // Closest remaining diagnostic (by row) to the local cursor.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Nothing suitable in the active buffer: look in other files with diagnostics.
        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // Pair each other file that has a diagnostic summary with the number of leading
            // path components it shares with the active buffer's path.
            // NOTE(review): `diagnostic_summaries` may yield one entry per language server,
            // so the same path could appear (and be opened) more than once; confirm.
            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
                project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .map(|(path, _, _)| {
                        let shared_prefix = path
                            .path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count();
                        (path, shared_prefix)
                    })
                    .collect()
            });

            // Stable sort: most shared path components first, ties keep their order.
            candidates.sort_by(|a, b| b.1.cmp(&a.1));

            for (path, _) in candidates {
                // A failure to open any candidate aborts the whole search via `?`.
                let candidate_buffer = project
                    .update(cx, |project, cx| project.open_buffer(path, cx))
                    .await?;

                let (has_collaborators, diagnostic_position) =
                    candidate_buffer.read_with(cx, |buffer, _cx| {
                        let snapshot = buffer.snapshot();
                        let has_collaborators = snapshot
                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
                            .next()
                            .is_some();
                        // Entry with the lowest severity value. Presumably the most severe,
                        // as in LSP where ERROR = 1; confirm for this severity type.
                        let position = buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start);
                        (has_collaborators, position)
                    });

                // Skip any buffer in which `selections_in_range` reports selections.
                if has_collaborators {
                    continue;
                }

                if let Some(position) = diagnostic_position {
                    jump_location = Some((candidate_buffer, position));
                    break;
                }
            }
        }

        anyhow::Ok(jump_location)
    }
2575
2576    async fn send_raw_llm_request(
2577        request: RawCompletionRequest,
2578        client: Arc<Client>,
2579        custom_url: Option<Arc<Url>>,
2580        llm_token: LlmApiToken,
2581        organization_id: Option<OrganizationId>,
2582        app_version: Version,
2583    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2584        let url = if let Some(custom_url) = custom_url {
2585            custom_url.as_ref().clone()
2586        } else {
2587            client
2588                .http_client()
2589                .build_zed_llm_url("/predict_edits/raw", &[])?
2590        };
2591
2592        Self::send_api_request(
2593            |builder| {
2594                let req = builder
2595                    .uri(url.as_ref())
2596                    .body(serde_json::to_string(&request)?.into());
2597                Ok(req?)
2598            },
2599            client,
2600            llm_token,
2601            organization_id,
2602            app_version,
2603            true,
2604        )
2605        .await
2606    }
2607
2608    pub(crate) async fn send_v3_request(
2609        input: ZetaPromptInput,
2610        client: Arc<Client>,
2611        llm_token: LlmApiToken,
2612        organization_id: Option<OrganizationId>,
2613        app_version: Version,
2614        trigger: PredictEditsRequestTrigger,
2615    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2616        let url = client
2617            .http_client()
2618            .build_zed_llm_url("/predict_edits/v3", &[])?;
2619
2620        let request = PredictEditsV3Request { input, trigger };
2621
2622        let json_bytes = serde_json::to_vec(&request)?;
2623        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2624
2625        Self::send_api_request(
2626            |builder| {
2627                let req = builder
2628                    .uri(url.as_ref())
2629                    .header("Content-Encoding", "zstd")
2630                    .body(compressed.clone().into());
2631                Ok(req?)
2632            },
2633            client,
2634            llm_token,
2635            organization_id,
2636            app_version,
2637            true,
2638        )
2639        .await
2640    }
2641
2642    async fn send_api_request<Res>(
2643        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2644        client: Arc<Client>,
2645        llm_token: LlmApiToken,
2646        organization_id: Option<OrganizationId>,
2647        app_version: Version,
2648        require_auth: bool,
2649    ) -> Result<(Res, Option<EditPredictionUsage>)>
2650    where
2651        Res: DeserializeOwned,
2652    {
2653        let http_client = client.http_client();
2654
2655        let mut token = if require_auth {
2656            Some(llm_token.acquire(&client, organization_id.clone()).await?)
2657        } else {
2658            llm_token
2659                .acquire(&client, organization_id.clone())
2660                .await
2661                .ok()
2662        };
2663        let mut did_retry = false;
2664
2665        loop {
2666            let request_builder = http_client::Request::builder().method(Method::POST);
2667
2668            let mut request_builder = request_builder
2669                .header("Content-Type", "application/json")
2670                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2671
2672            // Only add Authorization header if we have a token
2673            if let Some(ref token_value) = token {
2674                request_builder =
2675                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2676            }
2677
2678            let request = build(request_builder)?;
2679
2680            let mut response = http_client.send(request).await?;
2681
2682            if let Some(minimum_required_version) = response
2683                .headers()
2684                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2685                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2686            {
2687                anyhow::ensure!(
2688                    app_version >= minimum_required_version,
2689                    ZedUpdateRequiredError {
2690                        minimum_version: minimum_required_version
2691                    }
2692                );
2693            }
2694
2695            if response.status().is_success() {
2696                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2697
2698                let mut body = Vec::new();
2699                response.body_mut().read_to_end(&mut body).await?;
2700                return Ok((serde_json::from_slice(&body)?, usage));
2701            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2702                did_retry = true;
2703                token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
2704            } else {
2705                let mut body = String::new();
2706                response.body_mut().read_to_string(&mut body).await?;
2707                anyhow::bail!(
2708                    "Request failed with status: {:?}\nBody: {}",
2709                    response.status(),
2710                    body
2711                );
2712            }
2713        }
2714    }
2715
2716    pub fn refresh_context(
2717        &mut self,
2718        project: &Entity<Project>,
2719        buffer: &Entity<language::Buffer>,
2720        cursor_position: language::Anchor,
2721        cx: &mut Context<Self>,
2722    ) {
2723        self.get_or_init_project(project, cx)
2724            .context
2725            .update(cx, |store, cx| {
2726                store.refresh(buffer.clone(), cursor_position, cx);
2727            });
2728    }
2729
2730    #[cfg(feature = "cli-support")]
2731    pub fn set_context_for_buffer(
2732        &mut self,
2733        project: &Entity<Project>,
2734        related_files: Vec<RelatedFile>,
2735        cx: &mut Context<Self>,
2736    ) {
2737        self.get_or_init_project(project, cx)
2738            .context
2739            .update(cx, |store, cx| {
2740                store.set_related_files(related_files, cx);
2741            });
2742    }
2743
2744    #[cfg(feature = "cli-support")]
2745    pub fn set_recent_paths_for_project(
2746        &mut self,
2747        project: &Entity<Project>,
2748        paths: impl IntoIterator<Item = project::ProjectPath>,
2749        cx: &mut Context<Self>,
2750    ) {
2751        let project_state = self.get_or_init_project(project, cx);
2752        project_state.recent_paths = paths.into_iter().collect();
2753    }
2754
2755    fn is_file_open_source(
2756        &self,
2757        project: &Entity<Project>,
2758        file: &Arc<dyn File>,
2759        cx: &App,
2760    ) -> bool {
2761        if !file.is_local() || file.is_private() {
2762            return false;
2763        }
2764        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2765            return false;
2766        };
2767        project_state
2768            .license_detection_watchers
2769            .get(&file.worktree_id(cx))
2770            .as_ref()
2771            .is_some_and(|watcher| watcher.is_project_open_source())
2772    }
2773
2774    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2775        self.data_collection_choice.is_enabled(cx)
2776    }
2777
2778    fn load_data_collection_choice(cx: &App) -> DataCollectionChoice {
2779        let choice = KeyValueStore::global(cx)
2780            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2781            .log_err()
2782            .flatten();
2783
2784        match choice.as_deref() {
2785            Some("true") => DataCollectionChoice::Enabled,
2786            Some("false") => DataCollectionChoice::Disabled,
2787            Some(_) => {
2788                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2789                DataCollectionChoice::NotAnswered
2790            }
2791            None => DataCollectionChoice::NotAnswered,
2792        }
2793    }
2794
2795    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2796        self.data_collection_choice = self.data_collection_choice.toggle();
2797        let new_choice = self.data_collection_choice;
2798        let is_enabled = new_choice.is_enabled(cx);
2799        let kvp = KeyValueStore::global(cx);
2800        db::write_and_log(cx, move || async move {
2801            kvp.write_kvp(
2802                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2803                is_enabled.to_string(),
2804            )
2805            .await
2806        });
2807    }
2808
2809    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2810        self.shown_predictions.iter()
2811    }
2812
2813    pub fn shown_completions_len(&self) -> usize {
2814        self.shown_predictions.len()
2815    }
2816
2817    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2818        self.rated_predictions.contains(id)
2819    }
2820
2821    pub fn rate_prediction(
2822        &mut self,
2823        prediction: &EditPrediction,
2824        rating: EditPredictionRating,
2825        feedback: String,
2826        cx: &mut Context<Self>,
2827    ) {
2828        let organization = self.user_store.read(cx).current_organization();
2829
2830        self.rated_predictions.insert(prediction.id.clone());
2831
2832        cx.background_spawn({
2833            let client = self.client.clone();
2834            let prediction_id = prediction.id.to_string();
2835            let inputs = serde_json::to_value(&prediction.inputs);
2836            let output = prediction
2837                .edit_preview
2838                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2839            async move {
2840                client
2841                    .cloud_client()
2842                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2843                        organization_id: organization.map(|organization| organization.id.clone()),
2844                        request_id: prediction_id,
2845                        rating: match rating {
2846                            EditPredictionRating::Positive => "positive".to_string(),
2847                            EditPredictionRating::Negative => "negative".to_string(),
2848                        },
2849                        inputs: inputs?,
2850                        output,
2851                        feedback,
2852                    })
2853                    .await?;
2854
2855                anyhow::Ok(())
2856            }
2857        })
2858        .detach_and_log_err(cx);
2859
2860        cx.notify();
2861    }
2862}
2863
2864fn collaborator_edit_overlaps_locality_region(
2865    project_state: &ProjectState,
2866    project: &Entity<Project>,
2867    buffer: &Entity<Buffer>,
2868    snapshot: &BufferSnapshot,
2869    edit_range: &Range<Anchor>,
2870    cx: &App,
2871) -> bool {
2872    let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2873        return false;
2874    };
2875
2876    if active_buffer.entity_id() != buffer.entity_id() {
2877        return false;
2878    }
2879
2880    let locality_point_range = expand_context_syntactically_then_linewise(
2881        snapshot,
2882        (position..position).to_point(snapshot),
2883        COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2884    );
2885    let locality_anchor_range = snapshot.anchor_range_around(locality_point_range);
2886
2887    edit_range.overlaps(&locality_anchor_range, snapshot)
2888}
2889
/// Collapses the trailing run of mergeable events in `events` into a single
/// `BufferChange` event.
///
/// Does nothing when:
/// - the newest event's `old_snapshot` has a different `remote_id()` than
///   `latest_snapshot`;
/// - `latest_snapshot` has not observed the newest event's `new_snapshot_version`;
/// - fewer than two trailing events are mergeable (each older event must satisfy
///   `can_merge` with the next newer one).
///
/// Otherwise the union of the run's `total_edit_range`s (compared in `latest_snapshot`)
/// is diffed between the oldest event's `old_snapshot` and `end_snapshot`. If a diff is
/// produced, the run is replaced by one event that keeps the oldest event's paths and
/// open-source flag, and is marked `predicted` only if every merged event was a
/// predicted `BufferChange`. If no diff is produced, `events` is left unchanged.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Walk backwards from the newest event, counting the consecutive trailing events
    // that can be merged with their newer neighbor.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A single event (or none) has nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    // Union of all edit ranges in the run, starting after the oldest event.
    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // `peek` did not advance the iterator, so this covers every event in the
                    // run, including the oldest.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                // Re-anchor the diff's range in `end_snapshot`.
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2972
2973fn merge_anchor_ranges(
2974    left: &Range<Anchor>,
2975    right: &Range<Anchor>,
2976    snapshot: &TextBufferSnapshot,
2977) -> Range<Anchor> {
2978    let start = if left.start.cmp(&right.start, snapshot).is_le() {
2979        left.start
2980    } else {
2981        right.start
2982    };
2983    let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2984        left.end
2985    } else {
2986        right.end
2987    };
2988    start..end
2989}
2990
/// Error returned when the server requires a newer Zed version than the one making the
/// request.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version advertised by the server in the response headers.
    minimum_version: Version,
}
2998
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been recorded yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
3005
3006impl DataCollectionChoice {
3007    pub fn is_enabled(self, cx: &App) -> bool {
3008        if cx.is_staff() {
3009            return true;
3010        }
3011        match self {
3012            Self::Enabled => true,
3013            Self::NotAnswered | Self::Disabled => false,
3014        }
3015    }
3016
3017    #[must_use]
3018    pub fn toggle(&self) -> DataCollectionChoice {
3019        match self {
3020            Self::Enabled => Self::Disabled,
3021            Self::Disabled => Self::Enabled,
3022            Self::NotAnswered => Self::Enabled,
3023        }
3024    }
3025}
3026
3027impl From<bool> for DataCollectionChoice {
3028    fn from(value: bool) -> Self {
3029        match value {
3030            true => DataCollectionChoice::Enabled,
3031            false => DataCollectionChoice::Disabled,
3032        }
3033    }
3034}
3035
3036struct ZedPredictUpsell;
3037
3038impl Dismissable for ZedPredictUpsell {
3039    const KEY: &'static str = "dismissed-edit-predict-upsell";
3040
3041    fn dismissed(cx: &App) -> bool {
3042        // To make this backwards compatible with older versions of Zed, we
3043        // check if the user has seen the previous Edit Prediction Onboarding
3044        // before, by checking the data collection choice which was written to
3045        // the database once the user clicked on "Accept and Enable"
3046        let kvp = KeyValueStore::global(cx);
3047        if kvp
3048            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
3049            .log_err()
3050            .is_some_and(|s| s.is_some())
3051        {
3052            return true;
3053        }
3054
3055        kvp.read_kvp(Self::KEY)
3056            .log_err()
3057            .is_some_and(|s| s.is_some())
3058    }
3059}
3060
3061pub fn should_show_upsell_modal(cx: &App) -> bool {
3062    !ZedPredictUpsell::dismissed(cx)
3063}
3064
3065pub fn init(cx: &mut App) {
3066    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
3067        workspace.register_action(
3068            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
3069                ZedPredictModal::toggle(
3070                    workspace,
3071                    workspace.user_store().clone(),
3072                    workspace.client().clone(),
3073                    window,
3074                    cx,
3075                )
3076            },
3077        );
3078
3079        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
3080            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
3081                settings
3082                    .project
3083                    .all_languages
3084                    .edit_predictions
3085                    .get_or_insert_default()
3086                    .provider = Some(EditPredictionProvider::None)
3087            });
3088        });
3089        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
3090            EditPredictionStore::try_global(cx).and_then(|store| {
3091                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
3092            })
3093        }
3094
3095        workspace.register_action(|workspace, _: &SignIn, window, cx| {
3096            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3097                copilot_ui::initiate_sign_in(copilot, window, cx);
3098            }
3099        });
3100        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3101            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3102                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3103            }
3104        });
3105        workspace.register_action(|workspace, _: &SignOut, window, cx| {
3106            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3107                copilot_ui::initiate_sign_out(copilot, window, cx);
3108            }
3109        });
3110    })
3111    .detach();
3112}