edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
   5use cloud_llm_client::predict_edits_v3::{
   6    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   7};
   8use cloud_llm_client::{
   9    EditPredictionRejectReason, EditPredictionRejection,
  10    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  11    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  12};
  13use collections::{HashMap, HashSet};
  14use copilot::{Copilot, Reinstall, SignIn, SignOut};
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::{
  19    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  20    channel::mpsc::{self, UnboundedReceiver},
  21    select_biased,
  22};
  23use gpui::BackgroundExecutor;
  24use gpui::http_client::Url;
  25use gpui::{
  26    App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
  27    http_client::{self, AsyncBody, Method},
  28    prelude::*,
  29};
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{
  39    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  40};
  41use std::collections::{VecDeque, hash_map};
  42use std::env;
  43use text::{AnchorRangeExt, Edit};
  44use workspace::Workspace;
  45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  46
  47use std::mem;
  48use std::ops::Range;
  49use std::path::Path;
  50use std::rc::Rc;
  51use std::str::FromStr as _;
  52use std::sync::Arc;
  53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  54use thiserror::Error;
  55use util::{RangeExt as _, ResultExt as _};
  56
  57pub mod cursor_excerpt;
  58pub mod example_spec;
  59pub mod fim;
  60mod license_detection;
  61pub mod mercury;
  62pub mod ollama;
  63mod onboarding_modal;
  64pub mod open_ai_response;
  65mod prediction;
  66pub mod sweep_ai;
  67
  68pub mod udiff;
  69
  70mod capture_example;
  71pub mod open_ai_compatible;
  72mod zed_edit_prediction_delegate;
  73pub mod zeta;
  74
  75#[cfg(test)]
  76mod edit_prediction_tests;
  77
  78use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
  79use crate::example_spec::ExampleSpec;
  80use crate::license_detection::LicenseDetectionWatcher;
  81use crate::mercury::Mercury;
  82use crate::onboarding_modal::ZedPredictModal;
  83pub use crate::prediction::EditPrediction;
  84pub use crate::prediction::EditPredictionId;
  85use crate::prediction::EditPredictionResult;
  86pub use crate::sweep_ai::SweepAi;
  87pub use capture_example::capture_example;
  88pub use language_model::ApiKeyState;
  89pub use telemetry_events::EditPredictionRating;
  90pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  91
// Declares the gpui actions `ResetOnboarding` and `ClearHistory` under the
// `edit_prediction` action namespace. Their handlers are registered elsewhere.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
 101
 102/// Maximum number of events to track.
 103const EVENT_COUNT_MAX: usize = 10;
 104const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 105const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
 106const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
 107const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 108const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 109const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
 110const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
 111const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
 112
/// Feature flag registered under the name `edit_prediction_jumps`.
// NOTE(review): by its name this presumably gates jump-style predictions
// (cf. `BufferEditPrediction::Jump`) — confirm at the call sites.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
 118
/// Wrapper that stores the shared `EditPredictionStore` entity as a gpui `Global`.
///
/// Installed by `EditPredictionStore::global` and read by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 123
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// `EditPredictionStore::zeta2_raw_config_from_env` builds this from the
/// `ZED_ZETA_FORMAT`, `ZED_ZETA_MODEL`, and `ZED_ZETA_ENVIRONMENT` environment variables.
// NOTE(review): this comment previously stated "The version is also used as the Baseten
// environment name (lowercased)". There is no `version` field, and an explicit
// `environment` field now exists — confirm how the environment name is actually derived.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Model identifier; read from `ZED_ZETA_MODEL` when loaded from the environment.
    pub model_id: Option<String>,
    /// Environment name; read from `ZED_ZETA_ENVIRONMENT` when loaded from the environment.
    pub environment: Option<String>,
    /// Prompt format; parsed from `ZED_ZETA_FORMAT` when loaded from the environment.
    pub format: ZetaFormat,
}
 133
/// Central edit prediction state.
///
/// Holds per-project state, backend/model selection, experiment configuration, the
/// data-collection choice, and the channels that feed the rejection-reporting and
/// settled-prediction workers spawned in `new`.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    /// Task spawned in `new` that waits for a current user, then refreshes experiments.
    _fetch_experiments_task: Task<()>,
    /// Per-project state (presumably keyed by the project's entity id).
    projects: HashMap<EntityId, ProjectState>,
    update_required: bool,
    /// Prediction backend in use; `new` initializes it to `EditPredictionModel::Zeta`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration; `new` reads it from environment variables.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    /// Experiment explicitly chosen by the user; takes precedence in `active_experiment`.
    preferred_experiment: Option<String>,
    /// Known experiments (presumably filled by `refresh_available_experiments`).
    available_experiments: Vec<String>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Sender drained by the background rejection-reporting task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Sender drained by `run_settled_predictions_worker`, spawned in `new`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    /// Predictions that were shown; `active_experiment` inspects their model versions.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
 155
/// A rejection sent over `EditPredictionStore::reject_predictions_tx`, paired with the
/// organization associated with the request (if any).
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}
 160
/// The backend used to produce edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    /// Zed's Zeta model; the default selected in `EditPredictionStore::new`.
    Zeta,
    /// FIM-style prompting with the given prompt format (see the `fim` module).
    Fim { format: EditPredictionPromptFormat },
    /// Sweep backend (see the `sweep_ai` module).
    Sweep,
    /// Mercury backend (see the `mercury` module).
    Mercury,
}
 168
/// Inputs gathered for a single prediction request and handed to a backend.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    /// Buffer the prediction is requested for, together with a snapshot of it.
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Requested position in `buffer` (presumably the cursor).
    position: Anchor,
    /// Recent edit history events.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional channel for `DebugEvent`s about this request.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    /// Recent user actions (see `UserActionRecord`).
    pub user_actions: Vec<UserActionRecord>,
}
 185
/// Debug events sent over a project's optional debug channel
/// (`ProjectState::debug_tx`) about context retrieval and prediction progress.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Payload of `DebugEvent::ContextRetrievalStarted`.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Payload of `DebugEvent::ContextRetrievalFinished`.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Free-form key/value details about the retrieval.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Payload of `DebugEvent::EditPredictionStarted`.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// The prompt sent to the model, when one is available.
    pub prompt: Option<String>,
}

/// Payload of `DebugEvent::EditPredictionFinished`.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when one is available.
    pub model_output: Option<String>,
}
 221
/// Maximum number of entries kept in `ProjectState::user_actions`; once full, the
/// oldest entry is evicted (see `ProjectState::record_user_action`).
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One recorded user action, forwarded to backends via
/// `EditPredictionModelInput::user_actions`.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer in which the action happened.
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    /// Milliseconds since the Unix epoch when the action was recorded.
    pub timestamp_epoch_ms: u64,
}

/// Kind of user action captured in a `UserActionRecord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 241
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The finalized buffer-change event.
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Buffer version after the change; `can_merge` uses it to check contiguity.
    pub new_snapshot_version: clock::Global,
    /// Range covering the change, anchored in the post-change buffer.
    pub total_edit_range: Range<Anchor>,
}
 250
 251impl StoredEvent {
 252    fn can_merge(
 253        &self,
 254        next_old_event: &StoredEvent,
 255        latest_snapshot: &TextBufferSnapshot,
 256        latest_edit_range: &Range<Anchor>,
 257    ) -> bool {
 258        // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
 259        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 260            return false;
 261        }
 262        if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
 263            return false;
 264        }
 265        if self.new_snapshot_version != next_old_event.old_snapshot.version {
 266            return false;
 267        }
 268        if !latest_snapshot
 269            .version
 270            .observed_all(&next_old_event.new_snapshot_version)
 271        {
 272            return false;
 273        }
 274
 275        let a_is_predicted = matches!(
 276            self.event.as_ref(),
 277            zeta_prompt::Event::BufferChange {
 278                predicted: true,
 279                ..
 280            }
 281        );
 282        let b_is_predicted = matches!(
 283            next_old_event.event.as_ref(),
 284            zeta_prompt::Event::BufferChange {
 285                predicted: true,
 286                ..
 287            }
 288        );
 289
 290        // If events come from the same source (both predicted or both manual) then
 291        // we would have coalesced them already.
 292        if a_is_predicted == b_is_predicted {
 293            return false;
 294        }
 295
 296        let left_range = self.total_edit_range.to_point(latest_snapshot);
 297        let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
 298        let latest_range = latest_edit_range.to_point(latest_snapshot);
 299
 300        // Events near to the latest edit are not merged if their sources differ.
 301        if lines_between_ranges(&left_range, &latest_range)
 302            .min(lines_between_ranges(&right_range, &latest_range))
 303            <= CHANGE_GROUPING_LINE_SPAN
 304        {
 305            return false;
 306        }
 307
 308        // Events that are distant from each other are not merged.
 309        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 310            return false;
 311        }
 312
 313        true
 314    }
 315}
 316
 317fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 318    if left.start > right.end {
 319        return left.start.row - right.end.row;
 320    }
 321    if right.start > left.end {
 322        return right.start.row - left.end.row;
 323    }
 324    0
 325}
 326
/// Per-project edit prediction state, stored in `EditPredictionStore::projects`.
struct ProjectState {
    /// Finalized edit events; `events()` appends the in-progress `last_event` to these.
    events: VecDeque<StoredEvent>,
    /// The edit event still being accumulated, if any.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Tracked buffers, keyed by buffer entity id (see `active_buffer`).
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests; capacity is fixed at two.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled (see `cancel_pending_prediction`).
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree; `LastEvent::finalize` uses them to decide
    /// whether an event comes from an open-source project.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history of recent user actions (see `record_user_action`).
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
 345
 346impl ProjectState {
 347    fn record_user_action(&mut self, action: UserActionRecord) {
 348        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 349            self.user_actions.pop_front();
 350        }
 351        self.user_actions.push_back(action);
 352    }
 353
 354    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 355        self.events
 356            .iter()
 357            .cloned()
 358            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 359                let (one, two) = event.split_by_pause();
 360                let one = one.finalize(&self.license_detection_watchers, cx);
 361                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 362                one.into_iter().chain(two)
 363            }))
 364            .collect()
 365    }
 366
 367    fn cancel_pending_prediction(
 368        &mut self,
 369        pending_prediction: PendingPrediction,
 370        cx: &mut Context<EditPredictionStore>,
 371    ) {
 372        self.cancelled_predictions.insert(pending_prediction.id);
 373
 374        if pending_prediction.drop_on_cancel {
 375            drop(pending_prediction.task);
 376        } else {
 377            cx.spawn(async move |this, cx| {
 378                let Some(prediction_id) = pending_prediction.task.await else {
 379                    return;
 380                };
 381
 382                this.update(cx, |this, cx| {
 383                    this.reject_prediction(
 384                        prediction_id,
 385                        EditPredictionRejectReason::Canceled,
 386                        false,
 387                        None,
 388                        cx,
 389                    );
 390                })
 391                .ok();
 392            })
 393            .detach()
 394        }
 395    }
 396
 397    fn active_buffer(
 398        &self,
 399        project: &Entity<Project>,
 400        cx: &App,
 401    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 402        let project = project.read(cx);
 403        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 404        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 405        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 406        Some((active_buffer, registered_buffer.last_position))
 407    }
 408}
 409
/// The prediction currently held for a project, plus display bookkeeping.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // NOTE(review): by their names, `was_shown`/`shown_with` record whether and how the
    // prediction was displayed — confirm at the write sites.
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
 417
 418impl CurrentEditPrediction {
 419    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 420        let Some(new_edits) = self
 421            .prediction
 422            .interpolate(&self.prediction.buffer.read(cx))
 423        else {
 424            return false;
 425        };
 426
 427        if self.prediction.buffer != old_prediction.prediction.buffer {
 428            return true;
 429        }
 430
 431        let Some(old_edits) = old_prediction
 432            .prediction
 433            .interpolate(&old_prediction.prediction.buffer.read(cx))
 434        else {
 435            return true;
 436        };
 437
 438        let requested_by_buffer_id = self.requested_by.buffer_id();
 439
 440        // This reduces the occurrence of UI thrash from replacing edits
 441        //
 442        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 443        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 444            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 445            && old_edits.len() == 1
 446            && new_edits.len() == 1
 447        {
 448            let (old_range, old_text) = &old_edits[0];
 449            let (new_range, new_text) = &new_edits[0];
 450            new_range == old_range && new_text.starts_with(old_text.as_ref())
 451        } else {
 452            true
 453        }
 454    }
 455}
 456
 457#[derive(Debug, Clone)]
 458enum PredictionRequestedBy {
 459    DiagnosticsUpdate,
 460    Buffer(EntityId),
 461}
 462
 463impl PredictionRequestedBy {
 464    pub fn buffer_id(&self) -> Option<EntityId> {
 465        match self {
 466            PredictionRequestedBy::DiagnosticsUpdate => None,
 467            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 468        }
 469    }
 470}
 471
// NOTE(review): not referenced in the visible part of this file; by its name, presumably
// a line radius used when searching for diagnostics — confirm at the use site.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostic search.
// NOTE(review): variant meanings (near the cursor vs. project-wide?) are inferred from
// the names only — confirm against the code that consumes this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
 479
/// An in-flight prediction request, tracked in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier inserted into `ProjectState::cancelled_predictions` on cancellation.
    id: usize,
    /// Resolves to the produced prediction's id, or `None` when there is nothing to reject.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 488
/// A prediction from the perspective of a buffer.
///
/// Both variants borrow the underlying `EditPrediction`.
// NOTE(review): by the names, `Local` presumably means the prediction edits this buffer and
// `Jump` that it points at another location — confirm where these are constructed.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
 495
 496#[cfg(test)]
 497impl std::ops::Deref for BufferEditPrediction<'_> {
 498    type Target = EditPrediction;
 499
 500    fn deref(&self) -> &Self::Target {
 501        match self {
 502            BufferEditPrediction::Local { prediction } => prediction,
 503            BufferEditPrediction::Jump { prediction } => prediction,
 504        }
 505    }
 506}
 507
/// A prediction waiting to be reported as "settled", stored per buffer in
/// `RegisteredBuffer::pending_predictions`.
// NOTE(review): the timestamps are presumably compared against
// `EDIT_PREDICTION_SETTLED_TTL` / `EDIT_PREDICTION_SETTLED_QUIESCENCE` — confirm.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Editable region of the prediction, anchored in the buffer.
    editable_anchor_range: Range<Anchor>,
    example: Option<ExampleSpec>,
    /// When this entry was queued.
    enqueued_at: Instant,
    /// When an edit last affected this entry (by name).
    last_edit_at: Instant,
}
 516
/// Tracking state for one buffer, stored in `ProjectState::registered_buffers`.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    /// Snapshot of the buffer (presumably the last one observed).
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last recorded position in the buffer; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
 524
/// An edit event still being accumulated for a buffer.
///
/// `finalize` turns it into a `StoredEvent`; `split_by_pause` separates the edits made
/// before and after `snapshot_after_last_editing_pause`.
#[derive(Clone)]
struct LastEvent {
    /// Buffer state at the start of the event.
    old_snapshot: TextBufferSnapshot,
    /// Latest buffer state included in the event.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    latest_edit_range: Range<Anchor>,
    /// Range covering all of the event's edits; `finalize` diffs around it.
    total_edit_range: Range<Anchor>,
    /// `total_edit_range` as of the last pause boundary; `split_by_pause` uses it for the
    /// pre-pause half, falling back to `total_edit_range`.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    /// Copied into `zeta_prompt::Event::BufferChange::predicted` by `finalize`.
    predicted: bool,
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 538
impl LastEvent {
    /// Converts this in-progress event into a `StoredEvent`.
    ///
    /// Diffs `old_snapshot` against `new_snapshot` around `total_edit_range`. Returns `None`
    /// when the old range cannot be mapped (see `compute_old_range_for_new_range`), or when
    /// the path is unchanged and the diff is empty.
    ///
    /// `in_open_source_repo` is true only if both the old and new file exist and each file's
    /// worktree has a license watcher reporting an open-source project.
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<StoredEvent> {
        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);

        // A missing file (e.g. an untitled buffer) makes the event count as not open source.
        let in_open_source_repo =
            [self.new_file.as_ref(), self.old_file.as_ref()]
                .iter()
                .all(|file| {
                    file.is_some_and(|file| {
                        license_detection_watchers
                            .get(&file.worktree_id(cx))
                            .is_some_and(|watcher| watcher.is_project_open_source())
                    })
                });

        let (diff, edit_range) = compute_diff_between_snapshots_in_range(
            &self.old_snapshot,
            &self.new_snapshot,
            &self.total_edit_range,
        )?;

        // Drop no-op events, but keep path changes even when the contents are unchanged.
        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path,
                    path,
                    diff,
                    in_open_source_repo,
                    predicted: self.predicted,
                }),
                old_snapshot: self.old_snapshot.clone(),
                new_snapshot_version: self.new_snapshot.version.clone(),
                total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
                    ..self.new_snapshot.anchor_before(edit_range.end),
            })
        }
    }

    /// Splits the event at `snapshot_after_last_editing_pause`.
    ///
    /// Returns `(self.clone(), None)` when there is no pause boundary or no edits occurred
    /// after it. Otherwise:
    /// - the first half spans `old_snapshot` → boundary, using
    ///   `total_edit_range_at_last_pause_boundary` (or `total_edit_range` if unset);
    /// - the second half spans boundary → `new_snapshot`, with a range recomputed from the
    ///   edits made since the boundary.
    ///
    /// In each half `latest_edit_range` equals `total_edit_range`, and pause info is cleared.
    // NOTE(review): both halves keep `self.old_file` / `self.new_file`; if the file was
    // renamed mid-event, the "after" half's `old_file` may not match the boundary — confirm.
    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
            return (self.clone(), None);
        };

        let total_edit_range_before_pause = self
            .total_edit_range_at_last_pause_boundary
            .clone()
            .unwrap_or_else(|| self.total_edit_range.clone());

        let Some(total_edit_range_after_pause) =
            compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
        else {
            return (self.clone(), None);
        };

        let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
        let latest_edit_range_after_pause = total_edit_range_after_pause.clone();

        let before = LastEvent {
            old_snapshot: self.old_snapshot.clone(),
            new_snapshot: boundary_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            latest_edit_range: latest_edit_range_before_pause,
            total_edit_range: total_edit_range_before_pause,
            total_edit_range_at_last_pause_boundary: None,
            predicted: self.predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        let after = LastEvent {
            old_snapshot: boundary_snapshot.clone(),
            new_snapshot: self.new_snapshot.clone(),
            old_file: self.old_file.clone(),
            new_file: self.new_file.clone(),
            latest_edit_range: latest_edit_range_after_pause,
            total_edit_range: total_edit_range_after_pause,
            total_edit_range_at_last_pause_boundary: None,
            predicted: self.predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: self.last_edit_time,
        };

        (before, Some(after))
    }
}
 632
 633fn compute_total_edit_range_between_snapshots(
 634    old_snapshot: &TextBufferSnapshot,
 635    new_snapshot: &TextBufferSnapshot,
 636) -> Option<Range<Anchor>> {
 637    let edits: Vec<Edit<usize>> = new_snapshot
 638        .edits_since::<usize>(&old_snapshot.version)
 639        .collect();
 640
 641    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 642    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 643    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 644
 645    Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
 646}
 647
/// Maps `total_edit_range` (anchored in `new_snapshot`) back to a point range in
/// `old_snapshot`.
///
/// Walks the edits between the two snapshots in order while tracking `delta`, the cumulative
/// length change (new minus old) of the edits already passed.
/// - An endpoint in unchanged text before an edit is shifted back by `delta`.
/// - An endpoint within an edit's new range, boundaries included, snaps to that edit's old
///   start (for the range start) or old end (for the range end).
/// - Endpoints past the last edit are shifted back by the total `delta` (saturating).
///
/// Returns `None` only if a shift inside the loop would underflow.
// NOTE(review): because boundaries are inclusive, a start exactly at `edit.new.end` snaps to
// `edit.old.start`, and an end exactly at `edit.new.start` snaps to `edit.old.end`. This can
// widen the old range to cover an adjacent edit — confirm this is intended.
fn compute_old_range_for_new_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<Range<Point>> {
    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);

    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();
    let mut old_start_offset = None;
    let mut old_end_offset = None;
    // Net (new - old) length change of all edits preceding the current one.
    let mut delta: isize = 0;

    for edit in &edits {
        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
            old_start_offset = Some(if new_start_offset < edit.new.start {
                // Unchanged text before this edit: undo the shift of earlier edits.
                new_start_offset.checked_add_signed(-delta)?
            } else {
                edit.old.start
            });
        }

        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
            old_end_offset = Some(if new_end_offset < edit.new.start {
                new_end_offset.checked_add_signed(-delta)?
            } else {
                edit.old.end
            });
        }

        delta += edit.new.len() as isize - edit.old.len() as isize;
    }

    // Endpoints past the final edit only need the total shift removed.
    let old_start_offset =
        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
    let old_end_offset =
        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));

    Some(
        old_snapshot.offset_to_point(old_start_offset)
            ..old_snapshot.offset_to_point(old_end_offset),
    )
}
 693
 694fn compute_diff_between_snapshots_in_range(
 695    old_snapshot: &TextBufferSnapshot,
 696    new_snapshot: &TextBufferSnapshot,
 697    total_edit_range: &Range<Anchor>,
 698) -> Option<(String, Range<Point>)> {
 699    let new_start_point = total_edit_range.start.to_point(new_snapshot);
 700    let new_end_point = total_edit_range.end.to_point(new_snapshot);
 701    let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
 702    let old_start_point = old_range.start;
 703    let old_end_point = old_range.end;
 704
 705    const CONTEXT_LINES: u32 = 3;
 706
 707    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 708    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 709    let old_context_end_row =
 710        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 711    let new_context_end_row =
 712        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 713
 714    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 715    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 716    let old_end_line_offset = old_snapshot
 717        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 718    let new_end_line_offset = new_snapshot
 719        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 720    let old_edit_range = old_start_line_offset..old_end_line_offset;
 721    let new_edit_range = new_start_line_offset..new_end_line_offset;
 722
 723    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 724    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 725
 726    let diff = language::unified_diff_with_offsets(
 727        &old_region_text,
 728        &new_region_text,
 729        old_context_start_row,
 730        new_context_start_row,
 731    );
 732
 733    Some((diff, new_start_point..new_end_point))
 734}
 735
 736fn buffer_path_with_id_fallback(
 737    file: Option<&Arc<dyn File>>,
 738    snapshot: &TextBufferSnapshot,
 739    cx: &App,
 740) -> Arc<Path> {
 741    if let Some(file) = file {
 742        file.full_path(cx).into()
 743    } else {
 744        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 745    }
 746}
 747
 748impl EditPredictionStore {
 749    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 750        cx.try_global::<EditPredictionStoreGlobal>()
 751            .map(|global| global.0.clone())
 752    }
 753
 754    pub fn global(
 755        client: &Arc<Client>,
 756        user_store: &Entity<UserStore>,
 757        cx: &mut App,
 758    ) -> Entity<Self> {
 759        cx.try_global::<EditPredictionStoreGlobal>()
 760            .map(|global| global.0.clone())
 761            .unwrap_or_else(|| {
 762                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 763                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 764                ep_store
 765            })
 766    }
 767
 768    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 769        let data_collection_choice = Self::load_data_collection_choice();
 770
 771        let llm_token = LlmApiToken::global(cx);
 772
 773        let (reject_tx, reject_rx) = mpsc::unbounded();
 774        cx.background_spawn({
 775            let client = client.clone();
 776            let llm_token = llm_token.clone();
 777            let app_version = AppVersion::global(cx);
 778            let background_executor = cx.background_executor().clone();
 779            async move {
 780                Self::handle_rejected_predictions(
 781                    reject_rx,
 782                    client,
 783                    llm_token,
 784                    app_version,
 785                    background_executor,
 786                )
 787                .await
 788            }
 789        })
 790        .detach();
 791
 792        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
 793        cx.spawn(async move |this, cx| {
 794            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
 795        })
 796        .detach();
 797
 798        let mut current_user = user_store.read(cx).watch_current_user();
 799        let fetch_experiments_task = cx.spawn(async move |this, cx| {
 800            while current_user.borrow().is_none() {
 801                current_user.next().await;
 802            }
 803            this.update(cx, |this, cx| {
 804                this.refresh_available_experiments(cx);
 805            })
 806            .log_err();
 807        });
 808
 809        let this = Self {
 810            projects: HashMap::default(),
 811            client,
 812            user_store,
 813            llm_token,
 814            _fetch_experiments_task: fetch_experiments_task,
 815            update_required: false,
 816            edit_prediction_model: EditPredictionModel::Zeta,
 817            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 818            preferred_experiment: None,
 819            available_experiments: Vec::new(),
 820            sweep_ai: SweepAi::new(cx),
 821            mercury: Mercury::new(cx),
 822
 823            data_collection_choice,
 824            reject_predictions_tx: reject_tx,
 825            settled_predictions_tx,
 826            rated_predictions: Default::default(),
 827            shown_predictions: Default::default(),
 828            #[cfg(test)]
 829            settled_event_callback: None,
 830        };
 831
 832        this
 833    }
 834
 835    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 836        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 837        let format = ZetaFormat::parse(&version_str).ok()?;
 838        let model_id = env::var("ZED_ZETA_MODEL").ok();
 839        let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
 840        Some(Zeta2RawConfig {
 841            model_id,
 842            environment,
 843            format,
 844        })
 845    }
 846
 847    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 848        self.edit_prediction_model = model;
 849    }
 850
 851    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
 852        self.zeta2_raw_config = Some(config);
 853    }
 854
 855    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
 856        self.zeta2_raw_config.as_ref()
 857    }
 858
 859    pub fn preferred_experiment(&self) -> Option<&str> {
 860        self.preferred_experiment.as_deref()
 861    }
 862
 863    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
 864        self.preferred_experiment = experiment;
 865    }
 866
 867    pub fn available_experiments(&self) -> &[String] {
 868        &self.available_experiments
 869    }
 870
 871    pub fn active_experiment(&self) -> Option<&str> {
 872        self.preferred_experiment.as_deref().or_else(|| {
 873            self.shown_predictions
 874                .iter()
 875                .find_map(|p| p.model_version.as_ref())
 876                .and_then(|model_version| model_version.strip_prefix("zeta2:"))
 877        })
 878    }
 879
 880    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
 881        let client = self.client.clone();
 882        let llm_token = self.llm_token.clone();
 883        let app_version = AppVersion::global(cx);
 884        let organization_id = self
 885            .user_store
 886            .read(cx)
 887            .current_organization()
 888            .map(|organization| organization.id.clone());
 889
 890        cx.spawn(async move |this, cx| {
 891            let experiments = cx
 892                .background_spawn(async move {
 893                    let http_client = client.http_client();
 894                    let token = llm_token.acquire(&client, organization_id).await?;
 895                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
 896                    let request = http_client::Request::builder()
 897                        .method(Method::GET)
 898                        .uri(url.as_ref())
 899                        .header("Authorization", format!("Bearer {}", token))
 900                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 901                        .body(Default::default())?;
 902                    let mut response = http_client.send(request).await?;
 903                    if response.status().is_success() {
 904                        let mut body = Vec::new();
 905                        response.body_mut().read_to_end(&mut body).await?;
 906                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
 907                        Ok(experiments)
 908                    } else {
 909                        let mut body = String::new();
 910                        response.body_mut().read_to_string(&mut body).await?;
 911                        anyhow::bail!(
 912                            "Failed to fetch experiments: {:?}\nBody: {}",
 913                            response.status(),
 914                            body
 915                        );
 916                    }
 917                })
 918                .await?;
 919            this.update(cx, |this, cx| {
 920                this.available_experiments = experiments;
 921                cx.notify();
 922            })?;
 923            anyhow::Ok(())
 924        })
 925        .detach_and_log_err(cx);
 926    }
 927
 928    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 929        use ui::IconName;
 930        match self.edit_prediction_model {
 931            EditPredictionModel::Sweep => {
 932                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 933                    .with_disabled(IconName::SweepAiDisabled)
 934                    .with_up(IconName::SweepAiUp)
 935                    .with_down(IconName::SweepAiDown)
 936                    .with_error(IconName::SweepAiError)
 937            }
 938            EditPredictionModel::Mercury => {
 939                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 940            }
 941            EditPredictionModel::Zeta => {
 942                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 943                    .with_disabled(IconName::ZedPredictDisabled)
 944                    .with_up(IconName::ZedPredictUp)
 945                    .with_down(IconName::ZedPredictDown)
 946                    .with_error(IconName::ZedPredictError)
 947            }
 948            EditPredictionModel::Fim { .. } => {
 949                let settings = &all_language_settings(None, cx).edit_predictions;
 950                match settings.provider {
 951                    EditPredictionProvider::Ollama => {
 952                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 953                    }
 954                    _ => {
 955                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 956                    }
 957                }
 958            }
 959        }
 960    }
 961
 962    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 963        self.sweep_ai.api_token.read(cx).has_key()
 964    }
 965
 966    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 967        self.mercury.api_token.read(cx).has_key()
 968    }
 969
 970    pub fn clear_history(&mut self) {
 971        for project_state in self.projects.values_mut() {
 972            project_state.events.clear();
 973            project_state.last_event.take();
 974        }
 975    }
 976
 977    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 978        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 979            project_state.events.clear();
 980            project_state.last_event.take();
 981        }
 982    }
 983
 984    pub fn edit_history_for_project(
 985        &self,
 986        project: &Entity<Project>,
 987        cx: &App,
 988    ) -> Vec<StoredEvent> {
 989        self.projects
 990            .get(&project.entity_id())
 991            .map(|project_state| project_state.events(cx))
 992            .unwrap_or_default()
 993    }
 994
 995    pub fn context_for_project<'a>(
 996        &'a self,
 997        project: &Entity<Project>,
 998        cx: &'a mut App,
 999    ) -> Vec<RelatedFile> {
1000        self.projects
1001            .get(&project.entity_id())
1002            .map(|project_state| {
1003                project_state.context.update(cx, |context, cx| {
1004                    context
1005                        .related_files_with_buffers(cx)
1006                        .map(|(mut related_file, buffer)| {
1007                            related_file.in_open_source_repo = buffer
1008                                .read(cx)
1009                                .file()
1010                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
1011                            related_file
1012                        })
1013                        .collect()
1014                })
1015            })
1016            .unwrap_or_default()
1017    }
1018
1019    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1020        self.projects
1021            .get(&project.entity_id())
1022            .and_then(|project| project.copilot.clone())
1023    }
1024
1025    pub fn start_copilot_for_project(
1026        &mut self,
1027        project: &Entity<Project>,
1028        cx: &mut Context<Self>,
1029    ) -> Option<Entity<Copilot>> {
1030        if DisableAiSettings::get(None, cx).disable_ai {
1031            return None;
1032        }
1033        let state = self.get_or_init_project(project, cx);
1034
1035        if state.copilot.is_some() {
1036            return state.copilot.clone();
1037        }
1038        let _project = project.clone();
1039        let project = project.read(cx);
1040
1041        let node = project.node_runtime().cloned();
1042        if let Some(node) = node {
1043            let next_id = project.languages().next_language_server_id();
1044            let fs = project.fs().clone();
1045
1046            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1047            state.copilot = Some(copilot.clone());
1048            Some(copilot)
1049        } else {
1050            None
1051        }
1052    }
1053
1054    pub fn context_for_project_with_buffers<'a>(
1055        &'a self,
1056        project: &Entity<Project>,
1057        cx: &'a mut App,
1058    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1059        self.projects
1060            .get(&project.entity_id())
1061            .map(|project| {
1062                project.context.update(cx, |context, cx| {
1063                    context.related_files_with_buffers(cx).collect()
1064                })
1065            })
1066            .unwrap_or_default()
1067    }
1068
1069    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1070        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1071            self.user_store.read(cx).edit_prediction_usage()
1072        } else {
1073            None
1074        }
1075    }
1076
1077    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1078        self.get_or_init_project(project, cx);
1079    }
1080
1081    pub fn register_buffer(
1082        &mut self,
1083        buffer: &Entity<Buffer>,
1084        project: &Entity<Project>,
1085        cx: &mut Context<Self>,
1086    ) {
1087        let project_state = self.get_or_init_project(project, cx);
1088        Self::register_buffer_impl(project_state, buffer, project, cx);
1089    }
1090
1091    fn get_or_init_project(
1092        &mut self,
1093        project: &Entity<Project>,
1094        cx: &mut Context<Self>,
1095    ) -> &mut ProjectState {
1096        let entity_id = project.entity_id();
1097        self.projects
1098            .entry(entity_id)
1099            .or_insert_with(|| ProjectState {
1100                context: {
1101                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
1102                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
1103                        this.handle_excerpt_store_event(entity_id, event);
1104                    })
1105                    .detach();
1106                    related_excerpt_store
1107                },
1108                events: VecDeque::new(),
1109                last_event: None,
1110                recent_paths: VecDeque::new(),
1111                debug_tx: None,
1112                registered_buffers: HashMap::default(),
1113                current_prediction: None,
1114                cancelled_predictions: HashSet::default(),
1115                pending_predictions: ArrayVec::new(),
1116                next_pending_prediction_id: 0,
1117                last_edit_prediction_refresh: None,
1118                last_jump_prediction_refresh: None,
1119                license_detection_watchers: HashMap::default(),
1120                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
1121                _subscriptions: [
1122                    cx.subscribe(&project, Self::handle_project_event),
1123                    cx.observe_release(&project, move |this, _, cx| {
1124                        this.projects.remove(&entity_id);
1125                        cx.notify();
1126                    }),
1127                ],
1128                copilot: None,
1129            })
1130    }
1131
1132    pub fn remove_project(&mut self, project: &Entity<Project>) {
1133        self.projects.remove(&project.entity_id());
1134    }
1135
1136    fn handle_excerpt_store_event(
1137        &mut self,
1138        project_entity_id: EntityId,
1139        event: &RelatedExcerptStoreEvent,
1140    ) {
1141        if let Some(project_state) = self.projects.get(&project_entity_id) {
1142            if let Some(debug_tx) = project_state.debug_tx.clone() {
1143                match event {
1144                    RelatedExcerptStoreEvent::StartedRefresh => {
1145                        debug_tx
1146                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
1147                                ContextRetrievalStartedDebugEvent {
1148                                    project_entity_id: project_entity_id,
1149                                    timestamp: Instant::now(),
1150                                    search_prompt: String::new(),
1151                                },
1152                            ))
1153                            .ok();
1154                    }
1155                    RelatedExcerptStoreEvent::FinishedRefresh {
1156                        cache_hit_count,
1157                        cache_miss_count,
1158                        mean_definition_latency,
1159                        max_definition_latency,
1160                    } => {
1161                        debug_tx
1162                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
1163                                ContextRetrievalFinishedDebugEvent {
1164                                    project_entity_id: project_entity_id,
1165                                    timestamp: Instant::now(),
1166                                    metadata: vec![
1167                                        (
1168                                            "Cache Hits",
1169                                            format!(
1170                                                "{}/{}",
1171                                                cache_hit_count,
1172                                                cache_hit_count + cache_miss_count
1173                                            )
1174                                            .into(),
1175                                        ),
1176                                        (
1177                                            "Max LSP Time",
1178                                            format!("{} ms", max_definition_latency.as_millis())
1179                                                .into(),
1180                                        ),
1181                                        (
1182                                            "Mean LSP Time",
1183                                            format!("{} ms", mean_definition_latency.as_millis())
1184                                                .into(),
1185                                        ),
1186                                    ],
1187                                },
1188                            ))
1189                            .ok();
1190                    }
1191                }
1192            }
1193        }
1194    }
1195
1196    pub fn debug_info(
1197        &mut self,
1198        project: &Entity<Project>,
1199        cx: &mut Context<Self>,
1200    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1201        let project_state = self.get_or_init_project(project, cx);
1202        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1203        project_state.debug_tx = Some(debug_watch_tx);
1204        debug_watch_rx
1205    }
1206
1207    fn handle_project_event(
1208        &mut self,
1209        project: Entity<Project>,
1210        event: &project::Event,
1211        cx: &mut Context<Self>,
1212    ) {
1213        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1214            return;
1215        }
1216        // TODO [zeta2] init with recent paths
1217        match event {
1218            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1219                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1220                    return;
1221                };
1222                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1223                if let Some(path) = path {
1224                    if let Some(ix) = project_state
1225                        .recent_paths
1226                        .iter()
1227                        .position(|probe| probe == &path)
1228                    {
1229                        project_state.recent_paths.remove(ix);
1230                    }
1231                    project_state.recent_paths.push_front(path);
1232                }
1233            }
1234            project::Event::DiagnosticsUpdated { .. } => {
1235                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1236                    self.refresh_prediction_from_diagnostics(
1237                        project,
1238                        DiagnosticSearchScope::Global,
1239                        cx,
1240                    );
1241                }
1242            }
1243            _ => (),
1244        }
1245    }
1246
    /// Registers `buffer` with `project_state` and returns its `RegisteredBuffer`.
    ///
    /// Side effects:
    /// - If the buffer has a file in one of `project`'s worktrees, ensures a
    ///   `LicenseDetectionWatcher` exists for that worktree; the watcher entry is removed again
    ///   when the worktree entity is released.
    /// - On first registration, inserts a `RegisteredBuffer` holding the current text snapshot
    ///   and file, subscribed so that `BufferEvent::Edited` is forwarded to
    ///   `report_changes_for_buffer` (as a non-predicted edit) and so that the entry is removed
    ///   when the buffer entity is released.
    ///
    /// An already-registered buffer's entry is returned unchanged.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // One license watcher per worktree, created lazily the first time a buffer from that
        // worktree is seen.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher when its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        // Only a weak project handle is captured; edits are ignored once the
                        // project can no longer be upgraded.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        // Unregister the buffer when its entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1315
    /// Records the edits made to `buffer` since it was last observed.
    ///
    /// Diffs the buffer's current text snapshot against the one stored in its
    /// `RegisteredBuffer` (registering the buffer first if needed), then:
    /// - refreshes `last_edit_at` on this buffer's pending predictions whose editable range
    ///   overlaps the edited range;
    /// - for local edits, records a `UserActionRecord` classified from the edit sizes;
    /// - for edits that belong in history (local edits, or collaborator edits accepted by
    ///   `collaborator_edit_overlaps_locality_region`), either coalesces them into the project's
    ///   `last_event` or finalizes that event into `events` and starts a new `last_event`.
    ///
    /// `is_predicted` marks edits produced by accepting a prediction; a change in this flag
    /// prevents coalescing with the previous event. Returns early when the buffer version is
    /// unchanged or no edits are found.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        // Store the new file/snapshot; the replaced values describe the "before" state.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        // Tally edit sizes and build one anchor range from the first edit's start to the last
        // edit's end (in iteration order).
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if is_local {
            // No deletions -> insert; no insertions -> delete; otherwise judged by insertions.
            // A total length equal to the edit count (one offset unit per edit) counts as a
            // single-character action, anything else as a selection action.
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                // Wall-clock timestamp; falls back to 0 if the clock is before the Unix epoch.
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        if !include_in_history {
            return;
        }

        let events = &mut project_state.events;

        if let Some(last_event) = project_state.last_event.as_mut() {
            // The edit continues the previous event only if it starts exactly from that
            // event's latest snapshot of the same buffer.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // After a pause of at least LAST_CHANGE_GROUPING_TIME, remember the state just
                // before this edit as a pause boundary.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalesced: finalize the previous event (if it yields one) into history.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Evict the oldest entry first so `events` stays below EVENT_COUNT_MAX.
                // NOTE(review): presumably leaves room for `last_event` — confirm.
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1466
1467    fn prediction_at(
1468        &mut self,
1469        buffer: &Entity<Buffer>,
1470        position: Option<language::Anchor>,
1471        project: &Entity<Project>,
1472        cx: &App,
1473    ) -> Option<BufferEditPrediction<'_>> {
1474        let project_state = self.projects.get_mut(&project.entity_id())?;
1475        if let Some(position) = position
1476            && let Some(buffer) = project_state
1477                .registered_buffers
1478                .get_mut(&buffer.entity_id())
1479        {
1480            buffer.last_position = Some(position);
1481        }
1482
1483        let CurrentEditPrediction {
1484            requested_by,
1485            prediction,
1486            ..
1487        } = project_state.current_prediction.as_ref()?;
1488
1489        if prediction.targets_buffer(buffer.read(cx)) {
1490            Some(BufferEditPrediction::Local { prediction })
1491        } else {
1492            let show_jump = match requested_by {
1493                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1494                    requested_by_buffer_id == &buffer.entity_id()
1495                }
1496                PredictionRequestedBy::DiagnosticsUpdate => true,
1497            };
1498
1499            if show_jump {
1500                Some(BufferEditPrediction::Jump { prediction })
1501            } else {
1502                None
1503            }
1504        }
1505    }
1506
1507    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1508        let Some(current_prediction) = self
1509            .projects
1510            .get_mut(&project.entity_id())
1511            .and_then(|project_state| project_state.current_prediction.take())
1512        else {
1513            return;
1514        };
1515
1516        self.report_changes_for_buffer(
1517            &current_prediction.prediction.buffer,
1518            project,
1519            true,
1520            true,
1521            cx,
1522        );
1523
1524        // can't hold &mut project_state ref across report_changes_for_buffer_call
1525        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1526            return;
1527        };
1528
1529        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1530            project_state.cancel_pending_prediction(pending_prediction, cx);
1531        }
1532
1533        match self.edit_prediction_model {
1534            EditPredictionModel::Sweep => {
1535                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1536            }
1537            EditPredictionModel::Mercury => {
1538                mercury::edit_prediction_accepted(
1539                    current_prediction.prediction.id,
1540                    self.client.http_client(),
1541                    cx,
1542                );
1543            }
1544            EditPredictionModel::Zeta => {
1545                let is_cloud = !matches!(
1546                    all_language_settings(None, cx).edit_predictions.provider,
1547                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1548                );
1549                if is_cloud {
1550                    zeta::edit_prediction_accepted(self, current_prediction, cx)
1551                }
1552            }
1553            EditPredictionModel::Fim { .. } => {}
1554        }
1555    }
1556
1557    async fn handle_rejected_predictions(
1558        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
1559        client: Arc<Client>,
1560        llm_token: LlmApiToken,
1561        app_version: Version,
1562        background_executor: BackgroundExecutor,
1563    ) {
1564        let mut rx = std::pin::pin!(rx.peekable());
1565        let mut batched = Vec::new();
1566
1567        while let Some(EditPredictionRejectionPayload {
1568            rejection,
1569            organization_id,
1570        }) = rx.next().await
1571        {
1572            batched.push(rejection);
1573
1574            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1575                select_biased! {
1576                    next = rx.as_mut().peek().fuse() => {
1577                        if next.is_some() {
1578                            continue;
1579                        }
1580                    }
1581                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1582                }
1583            }
1584
1585            let url = client
1586                .http_client()
1587                .build_zed_llm_url("/predict_edits/reject", &[])
1588                .unwrap();
1589
1590            let flush_count = batched
1591                .len()
1592                // in case items have accumulated after failure
1593                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1594            let start = batched.len() - flush_count;
1595
1596            let body = RejectEditPredictionsBodyRef {
1597                rejections: &batched[start..],
1598            };
1599
1600            let result = Self::send_api_request::<()>(
1601                |builder| {
1602                    let req = builder
1603                        .uri(url.as_ref())
1604                        .body(serde_json::to_string(&body)?.into());
1605                    anyhow::Ok(req?)
1606                },
1607                client.clone(),
1608                llm_token.clone(),
1609                organization_id,
1610                app_version.clone(),
1611                true,
1612            )
1613            .await;
1614
1615            if result.log_err().is_some() {
1616                batched.drain(start..);
1617            }
1618        }
1619    }
1620
    /// Background loop that resolves "settled" edit predictions.
    ///
    /// Each message on `rx` is the time a prediction was enqueued. The loop sleeps
    /// until the earliest moment a pending prediction could have been quiet for
    /// `EDIT_PREDICTION_SETTLED_QUIESCENCE`, then scans every registered buffer of
    /// every project. For each pending prediction it does one of three things:
    /// - drops it once it is `EDIT_PREDICTION_SETTLED_TTL` old;
    /// - once its buffer has been quiet long enough, reads the current text of its
    ///   editable range, emits `EDIT_PREDICTION_SETTLED_EVENT` (plus the test-only
    ///   callback) and drops it;
    /// - otherwise keeps it and reschedules the next wake-up.
    ///
    /// The loop exits when `rx` is closed or the store entity has been dropped.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                // If `wake_time` has already passed, std's `Instant::duration_since`
                // saturates to zero (assuming `Instant` is `std::time::Instant`).
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                // Nothing is scheduled: block until something is enqueued.
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                // Coalesce any other already-queued wake-ups; the scan below looks at
                // every pending prediction regardless of which message woke us.
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();

            // Earliest `last_edit_at` among predictions that remain pending; used to
            // schedule the next wake-up.
            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                // Expired: discard without emitting an event.
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                // Quiet long enough: report the region's current text and discard.
                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                        example = pending_prediction.example.take(),
                                    );

                                    return false;
                                }

                                // Still active: track the earliest edit time to wake up for.
                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            // `None` here means nothing is pending; the next iteration waits on `rx`.
            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1705
1706    pub(crate) fn enqueue_settled_prediction(
1707        &mut self,
1708        request_id: EditPredictionId,
1709        project: &Entity<Project>,
1710        edited_buffer: &Entity<Buffer>,
1711        edited_buffer_snapshot: &BufferSnapshot,
1712        editable_offset_range: Range<usize>,
1713        example: Option<ExampleSpec>,
1714        cx: &mut Context<Self>,
1715    ) {
1716        let this = &mut *self;
1717        let project_state = this.get_or_init_project(project, cx);
1718        if let Some(buffer) = project_state
1719            .registered_buffers
1720            .get_mut(&edited_buffer.entity_id())
1721        {
1722            let now = cx.background_executor().now();
1723            buffer.pending_predictions.push(PendingSettledPrediction {
1724                request_id: request_id,
1725                editable_anchor_range: edited_buffer_snapshot
1726                    .anchor_range_around(editable_offset_range),
1727                example,
1728                enqueued_at: now,
1729                last_edit_at: now,
1730            });
1731            this.settled_predictions_tx.unbounded_send(now).ok();
1732        }
1733    }
1734
1735    fn reject_current_prediction(
1736        &mut self,
1737        reason: EditPredictionRejectReason,
1738        project: &Entity<Project>,
1739        cx: &App,
1740    ) {
1741        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1742            project_state.pending_predictions.clear();
1743            if let Some(prediction) = project_state.current_prediction.take() {
1744                let model_version = prediction.prediction.model_version.clone();
1745                self.reject_prediction(
1746                    prediction.prediction.id,
1747                    reason,
1748                    prediction.was_shown,
1749                    model_version,
1750                    cx,
1751                );
1752            }
1753        };
1754    }
1755
    /// Records that the current prediction for `project` was displayed as `display_type`.
    ///
    /// - `shown_with` is set on the first display. After that, only non-jump displays
    ///   overwrite it; a jump display never replaces an earlier value.
    /// - `was_shown` becomes true on the first non-jump display. That display also
    ///   pushes the prediction to the front of `shown_predictions`. The list is capped
    ///   at 50 entries; the oldest entry is evicted and its rating is forgotten.
    /// - With the Sweep model, `sweep_ai::edit_prediction_shown` is called whenever
    ///   `display_type` differs from the previously recorded `shown_with`.
    ///
    /// Does nothing if the project has no state or no current prediction.
    fn did_show_current_prediction(
        &mut self,
        project: &Entity<Project>,
        display_type: edit_prediction_types::SuggestionDisplayType,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
            return;
        };

        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
        let previous_shown_with = current_prediction.shown_with;

        // A jump display only records itself if nothing was recorded yet.
        if previous_shown_with.is_none() || !is_jump {
            current_prediction.shown_with = Some(display_type);
        }

        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;

        if is_first_non_jump_show {
            current_prediction.was_shown = true;
        }

        // NOTE(review): because jump displays do not overwrite `shown_with`, every
        // Jump display after a non-jump one compares against the stale value and is
        // reported to Sweep again — confirm this repetition is intended.
        let display_type_changed = previous_shown_with != Some(display_type);

        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
            sweep_ai::edit_prediction_shown(
                &self.sweep_ai,
                self.client.clone(),
                &current_prediction.prediction,
                display_type,
                cx,
            );
        }

        if is_first_non_jump_show {
            // Keep a bounded, most-recent-first history of shown predictions.
            self.shown_predictions
                .push_front(current_prediction.prediction.clone());
            if self.shown_predictions.len() > 50 {
                let completion = self.shown_predictions.pop_back().unwrap();
                self.rated_predictions.remove(&completion.id);
            }
        }
    }
1804
1805    fn reject_prediction(
1806        &mut self,
1807        prediction_id: EditPredictionId,
1808        reason: EditPredictionRejectReason,
1809        was_shown: bool,
1810        model_version: Option<String>,
1811        cx: &App,
1812    ) {
1813        match self.edit_prediction_model {
1814            EditPredictionModel::Zeta => {
1815                let is_cloud = !matches!(
1816                    all_language_settings(None, cx).edit_predictions.provider,
1817                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1818                );
1819
1820                if is_cloud {
1821                    let organization_id = self
1822                        .user_store
1823                        .read(cx)
1824                        .current_organization()
1825                        .map(|organization| organization.id.clone());
1826
1827                    self.reject_predictions_tx
1828                        .unbounded_send(EditPredictionRejectionPayload {
1829                            rejection: EditPredictionRejection {
1830                                request_id: prediction_id.to_string(),
1831                                reason,
1832                                was_shown,
1833                                model_version,
1834                            },
1835                            organization_id,
1836                        })
1837                        .log_err();
1838                }
1839            }
1840            EditPredictionModel::Mercury => {
1841                mercury::edit_prediction_rejected(
1842                    prediction_id,
1843                    was_shown,
1844                    reason,
1845                    self.client.http_client(),
1846                    cx,
1847                );
1848            }
1849            EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1850        }
1851    }
1852
1853    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1854        self.projects
1855            .get(&project.entity_id())
1856            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1857    }
1858
1859    pub fn refresh_prediction_from_buffer(
1860        &mut self,
1861        project: Entity<Project>,
1862        buffer: Entity<Buffer>,
1863        position: language::Anchor,
1864        cx: &mut Context<Self>,
1865    ) {
1866        self.queue_prediction_refresh(
1867            project.clone(),
1868            PredictEditsRequestTrigger::Other,
1869            buffer.entity_id(),
1870            cx,
1871            move |this, cx| {
1872                let Some(request_task) = this
1873                    .update(cx, |this, cx| {
1874                        this.request_prediction(
1875                            &project,
1876                            &buffer,
1877                            position,
1878                            PredictEditsRequestTrigger::Other,
1879                            cx,
1880                        )
1881                    })
1882                    .log_err()
1883                else {
1884                    return Task::ready(anyhow::Ok(None));
1885                };
1886
1887                cx.spawn(async move |_cx| {
1888                    request_task.await.map(|prediction_result| {
1889                        prediction_result.map(|prediction_result| {
1890                            (
1891                                prediction_result,
1892                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1893                            )
1894                        })
1895                    })
1896                })
1897            },
1898        )
1899    }
1900
1901    pub fn refresh_prediction_from_diagnostics(
1902        &mut self,
1903        project: Entity<Project>,
1904        scope: DiagnosticSearchScope,
1905        cx: &mut Context<Self>,
1906    ) {
1907        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1908            return;
1909        }
1910
1911        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1912            return;
1913        };
1914
1915        // Prefer predictions from buffer
1916        if project_state.current_prediction.is_some() {
1917            log::debug!(
1918                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
1919            );
1920            return;
1921        }
1922
1923        self.queue_prediction_refresh(
1924            project.clone(),
1925            PredictEditsRequestTrigger::Diagnostics,
1926            project.entity_id(),
1927            cx,
1928            move |this, cx| {
1929                let Some((active_buffer, snapshot, cursor_point)) = this
1930                    .read_with(cx, |this, cx| {
1931                        let project_state = this.projects.get(&project.entity_id())?;
1932                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
1933                        let snapshot = buffer.read(cx).snapshot();
1934
1935                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
1936                            return None;
1937                        }
1938
1939                        let cursor_point = position
1940                            .map(|pos| pos.to_point(&snapshot))
1941                            .unwrap_or_default();
1942
1943                        Some((buffer, snapshot, cursor_point))
1944                    })
1945                    .log_err()
1946                    .flatten()
1947                else {
1948                    return Task::ready(anyhow::Ok(None));
1949                };
1950
1951                cx.spawn(async move |cx| {
1952                    let diagnostic_search_range = match scope {
1953                        DiagnosticSearchScope::Local => {
1954                            let diagnostic_search_start =
1955                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1956                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1957                            Point::new(diagnostic_search_start, 0)
1958                                ..Point::new(diagnostic_search_end, 0)
1959                        }
1960                        DiagnosticSearchScope::Global => Default::default(),
1961                    };
1962
1963                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1964                        active_buffer,
1965                        &snapshot,
1966                        diagnostic_search_range,
1967                        cursor_point,
1968                        &project,
1969                        cx,
1970                    )
1971                    .await?
1972                    else {
1973                        return anyhow::Ok(None);
1974                    };
1975
1976                    let Some(prediction_result) = this
1977                        .update(cx, |this, cx| {
1978                            this.request_prediction(
1979                                &project,
1980                                &jump_buffer,
1981                                jump_position,
1982                                PredictEditsRequestTrigger::Diagnostics,
1983                                cx,
1984                            )
1985                        })?
1986                        .await?
1987                    else {
1988                        return anyhow::Ok(None);
1989                    };
1990
1991                    this.update(cx, |this, cx| {
1992                        Some((
1993                            if this
1994                                .get_or_init_project(&project, cx)
1995                                .current_prediction
1996                                .is_none()
1997                            {
1998                                prediction_result
1999                            } else {
2000                                EditPredictionResult {
2001                                    id: prediction_result.id,
2002                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
2003                                }
2004                            },
2005                            PredictionRequestedBy::DiagnosticsUpdate,
2006                        ))
2007                    })
2008                })
2009            },
2010        );
2011    }
2012
2013    fn predictions_enabled_at(
2014        snapshot: &BufferSnapshot,
2015        position: Option<language::Anchor>,
2016        cx: &App,
2017    ) -> bool {
2018        let file = snapshot.file();
2019        let all_settings = all_language_settings(file, cx);
2020        if !all_settings.show_edit_predictions(snapshot.language(), cx)
2021            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
2022        {
2023            return false;
2024        }
2025
2026        if let Some(last_position) = position {
2027            let settings = snapshot.settings_at(last_position, cx);
2028
2029            if !settings.edit_predictions_disabled_in.is_empty()
2030                && let Some(scope) = snapshot.language_scope_at(last_position)
2031                && let Some(scope_name) = scope.override_name()
2032                && settings
2033                    .edit_predictions_disabled_in
2034                    .iter()
2035                    .any(|s| s == scope_name)
2036            {
2037                return false;
2038            }
2039        }
2040
2041        true
2042    }
2043
    /// Throttle window for queued refreshes: a refresh whose trigger kind last fired
    /// for the same throttle entity less than this long ago waits out the remainder.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2045}
2046
2047fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2048    match provider {
2049        EditPredictionProvider::Zed
2050        | EditPredictionProvider::Sweep
2051        | EditPredictionProvider::Mercury
2052        | EditPredictionProvider::Ollama
2053        | EditPredictionProvider::OpenAiCompatibleApi
2054        | EditPredictionProvider::Experimental(_) => true,
2055        EditPredictionProvider::None
2056        | EditPredictionProvider::Copilot
2057        | EditPredictionProvider::Codestral => false,
2058    }
2059}
2060
2061impl EditPredictionStore {
    /// Queues a throttled, de-duplicated prediction refresh for `project`.
    ///
    /// The spawned task:
    /// 1. Waits out the rest of `THROTTLE_TIMEOUT` if the last refresh of the same
    ///    trigger kind (diagnostics vs. everything else) was for `throttle_entity`.
    /// 2. Stops if this request was cancelled meanwhile. It also stops if another
    ///    refresh of the same trigger kind went out after this one was enqueued.
    /// 3. Runs `do_refresh`. Unless cancelled, it installs the result as the
    ///    project's current prediction, subject to `should_replace_prediction`.
    ///    Predictions that are not installed are reported via `reject_prediction`.
    /// 4. Removes its own entry from `pending_predictions` and cancels every pending
    ///    prediction enqueued before it.
    ///
    /// The task yields the id of the produced prediction result, if any.
    ///
    /// At most `max_pending_predictions` entries (per provider) are kept. When the
    /// list is full, the most recently enqueued entry is replaced and cancelled.
    ///
    /// Logs an error and does nothing for providers not served by this store.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        request_trigger: PredictEditsRequestTrigger,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        // Diagnostic (jump) refreshes and all other refreshes are throttled through
        // separate slots, so one kind never delays the other.
        fn select_throttle(
            project_state: &mut ProjectState,
            request_trigger: PredictEditsRequestTrigger,
        ) -> &mut Option<(EntityId, Instant)> {
            match request_trigger {
                PredictEditsRequestTrigger::Diagnostics => {
                    &mut project_state.last_jump_prediction_refresh
                }
                _ => &mut project_state.last_edit_prediction_refresh,
            }
        }

        let (needs_acceptance_tracking, max_pending_predictions) =
            match all_language_settings(None, cx).edit_predictions.provider {
                EditPredictionProvider::Zed
                | EditPredictionProvider::Sweep
                | EditPredictionProvider::Mercury
                | EditPredictionProvider::Experimental(_) => (true, 2),
                EditPredictionProvider::Ollama => (false, 1),
                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
                EditPredictionProvider::None
                | EditPredictionProvider::Copilot
                | EditPredictionProvider::Codestral => {
                    log::error!("queue_prediction_refresh called with non-store provider");
                    return;
                }
            };

        // Presumably consumed by `cancel_pending_prediction`: providers without
        // acceptance tracking can drop their task outright when cancelled.
        let drop_on_cancel = !needs_acceptance_tracking;
        let throttle_timeout = Self::THROTTLE_TIMEOUT;
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        // Snapshot of the throttle slot. If it differs when the task is ready to run,
        // another refresh already went out and this one is redundant.
        let throttle_at_enqueue = *select_throttle(project_state, request_trigger);

        let task = cx.spawn(async move |this, cx| {
            let throttle_wait = this
                .update(cx, |this, cx| {
                    let project_state = this.get_or_init_project(&project, cx);
                    let throttle = *select_throttle(project_state, request_trigger);

                    // Only wait when the last refresh was for the same entity and the
                    // throttle window has not yet elapsed.
                    throttle.and_then(|(last_entity, last_timestamp)| {
                        if throttle_entity != last_entity {
                            return None;
                        }
                        (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
                    })
                })
                .ok()
                .flatten();

            if let Some(timeout) = throttle_wait {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request. Also skip if another task already
            // proceeded since we were enqueued (duplicate).
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                let was_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);
                if was_cancelled {
                    return;
                }

                // Another request has been already sent since this was enqueued
                if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
                    return;
                }

                // Claim the throttle slot for this request.
                let new_refresh = (throttle_entity, Instant::now());
                *select_throttle(project_state, request_trigger) = Some(new_refresh);
                is_cancelled = false;
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                // A cancellation may also have arrived while `do_refresh` was running.
                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                                shown_with: None,
                            };

                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                        cx,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                        new_prediction.prediction.model_version,
                                        cx,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(
                                prediction_result.id,
                                reject_reason,
                                false,
                                None,
                                cx,
                            );
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-borrow: `reject_current_prediction` above needed `this` mutably.
                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Remove this request's entry and cancel every entry that precedes it
                // (i.e. was enqueued earlier); later entries are kept.
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        if project_state.pending_predictions.len() < max_pending_predictions {
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
        } else {
            // Full: the most recently enqueued entry is replaced by this one and cancelled.
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
2264
2265    pub fn request_prediction(
2266        &mut self,
2267        project: &Entity<Project>,
2268        active_buffer: &Entity<Buffer>,
2269        position: language::Anchor,
2270        trigger: PredictEditsRequestTrigger,
2271        cx: &mut Context<Self>,
2272    ) -> Task<Result<Option<EditPredictionResult>>> {
2273        self.request_prediction_internal(
2274            project.clone(),
2275            active_buffer.clone(),
2276            position,
2277            trigger,
2278            cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2279            cx,
2280        )
2281    }
2282
2283    fn request_prediction_internal(
2284        &mut self,
2285        project: Entity<Project>,
2286        active_buffer: Entity<Buffer>,
2287        position: language::Anchor,
2288        trigger: PredictEditsRequestTrigger,
2289        allow_jump: bool,
2290        cx: &mut Context<Self>,
2291    ) -> Task<Result<Option<EditPredictionResult>>> {
2292        self.get_or_init_project(&project, cx);
2293        let project_state = self.projects.get(&project.entity_id()).unwrap();
2294        let stored_events = project_state.events(cx);
2295        let has_events = !stored_events.is_empty();
2296        let events: Vec<Arc<zeta_prompt::Event>> =
2297            stored_events.iter().map(|e| e.event.clone()).collect();
2298        let debug_tx = project_state.debug_tx.clone();
2299
2300        let snapshot = active_buffer.read(cx).snapshot();
2301        let cursor_point = position.to_point(&snapshot);
2302        let current_offset = position.to_offset(&snapshot);
2303
2304        let mut user_actions: Vec<UserActionRecord> =
2305            project_state.user_actions.iter().cloned().collect();
2306
2307        if let Some(last_action) = user_actions.last() {
2308            if last_action.buffer_id == active_buffer.entity_id()
2309                && current_offset != last_action.offset
2310            {
2311                let timestamp_epoch_ms = SystemTime::now()
2312                    .duration_since(UNIX_EPOCH)
2313                    .map(|d| d.as_millis() as u64)
2314                    .unwrap_or(0);
2315                user_actions.push(UserActionRecord {
2316                    action_type: UserActionType::CursorMovement,
2317                    buffer_id: active_buffer.entity_id(),
2318                    line_number: cursor_point.row,
2319                    offset: current_offset,
2320                    timestamp_epoch_ms,
2321                });
2322            }
2323        }
2324        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2325        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2326        let diagnostic_search_range =
2327            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2328
2329        let related_files = self.context_for_project(&project, cx);
2330
2331        let is_open_source = snapshot
2332            .file()
2333            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2334            && events.iter().all(|event| event.in_open_source_repo())
2335            && related_files.iter().all(|file| file.in_open_source_repo);
2336
2337        let can_collect_data = !cfg!(test)
2338            && is_open_source
2339            && self.is_data_collection_enabled(cx)
2340            && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2341
2342        let recent_paths = project_state.recent_paths.clone();
2343
2344        let inputs = EditPredictionModelInput {
2345            project: project.clone(),
2346            buffer: active_buffer,
2347            snapshot,
2348            position,
2349            events,
2350            related_files,
2351            recent_paths,
2352            trigger,
2353            diagnostic_search_range: diagnostic_search_range,
2354            debug_tx,
2355            user_actions,
2356            can_collect_data,
2357            is_open_source,
2358        };
2359
2360        let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2361
2362        let task = match self.edit_prediction_model {
2363            EditPredictionModel::Zeta => {
2364                zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2365            }
2366            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2367            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2368            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2369        };
2370
2371        cx.spawn(async move |this, cx| {
2372            let prediction = task.await?;
2373
2374            // Only fall back to diagnostics-based prediction if we got a
2375            // the model had nothing to suggest for the buffer
2376            if prediction.is_none()
2377                && allow_jump
2378                && has_events
2379                && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2380            {
2381                this.update(cx, |this, cx| {
2382                    this.refresh_prediction_from_diagnostics(
2383                        project,
2384                        DiagnosticSearchScope::Local,
2385                        cx,
2386                    );
2387                })?;
2388                return anyhow::Ok(None);
2389            }
2390
2391            Ok(prediction)
2392        })
2393    }
2394
2395    pub(crate) async fn next_diagnostic_location(
2396        active_buffer: Entity<Buffer>,
2397        active_buffer_snapshot: &BufferSnapshot,
2398        active_buffer_diagnostic_search_range: Range<Point>,
2399        active_buffer_cursor_point: Point,
2400        project: &Entity<Project>,
2401        cx: &mut AsyncApp,
2402    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2403        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2404            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2405            .flat_map(|(_, _, _, selections)| {
2406                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2407            })
2408            .collect();
2409
2410        let mut jump_location = active_buffer_snapshot
2411            .diagnostic_groups(None)
2412            .into_iter()
2413            .filter_map(|(_, group)| {
2414                let range = &group.entries[group.primary_ix]
2415                    .range
2416                    .to_point(&active_buffer_snapshot);
2417                if range.overlaps(&active_buffer_diagnostic_search_range) {
2418                    return None;
2419                }
2420                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2421                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2422                });
2423                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2424                    <= DIAGNOSTIC_LINES_RANGE;
2425                if near_collaborator && !near_local {
2426                    return None;
2427                }
2428                Some(range.start)
2429            })
2430            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2431            .map(|position| {
2432                (
2433                    active_buffer.clone(),
2434                    active_buffer_snapshot.anchor_before(position),
2435                )
2436            });
2437
2438        if jump_location.is_none() {
2439            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2440                let file = buffer.file()?;
2441
2442                Some(ProjectPath {
2443                    worktree_id: file.worktree_id(cx),
2444                    path: file.path().clone(),
2445                })
2446            });
2447
2448            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2449                project
2450                    .diagnostic_summaries(false, cx)
2451                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2452                    .map(|(path, _, _)| {
2453                        let shared_prefix = path
2454                            .path
2455                            .components()
2456                            .zip(
2457                                active_buffer_path
2458                                    .as_ref()
2459                                    .map(|p| p.path.components())
2460                                    .unwrap_or_default(),
2461                            )
2462                            .take_while(|(a, b)| a == b)
2463                            .count();
2464                        (path, shared_prefix)
2465                    })
2466                    .collect()
2467            });
2468
2469            candidates.sort_by(|a, b| b.1.cmp(&a.1));
2470
2471            for (path, _) in candidates {
2472                let candidate_buffer = project
2473                    .update(cx, |project, cx| project.open_buffer(path, cx))
2474                    .await?;
2475
2476                let (has_collaborators, diagnostic_position) =
2477                    candidate_buffer.read_with(cx, |buffer, _cx| {
2478                        let snapshot = buffer.snapshot();
2479                        let has_collaborators = snapshot
2480                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2481                            .next()
2482                            .is_some();
2483                        let position = buffer
2484                            .buffer_diagnostics(None)
2485                            .into_iter()
2486                            .min_by_key(|entry| entry.diagnostic.severity)
2487                            .map(|entry| entry.range.start);
2488                        (has_collaborators, position)
2489                    });
2490
2491                if has_collaborators {
2492                    continue;
2493                }
2494
2495                if let Some(position) = diagnostic_position {
2496                    jump_location = Some((candidate_buffer, position));
2497                    break;
2498                }
2499            }
2500        }
2501
2502        anyhow::Ok(jump_location)
2503    }
2504
2505    async fn send_raw_llm_request(
2506        request: RawCompletionRequest,
2507        client: Arc<Client>,
2508        custom_url: Option<Arc<Url>>,
2509        llm_token: LlmApiToken,
2510        organization_id: Option<OrganizationId>,
2511        app_version: Version,
2512    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2513        let url = if let Some(custom_url) = custom_url {
2514            custom_url.as_ref().clone()
2515        } else {
2516            client
2517                .http_client()
2518                .build_zed_llm_url("/predict_edits/raw", &[])?
2519        };
2520
2521        Self::send_api_request(
2522            |builder| {
2523                let req = builder
2524                    .uri(url.as_ref())
2525                    .body(serde_json::to_string(&request)?.into());
2526                Ok(req?)
2527            },
2528            client,
2529            llm_token,
2530            organization_id,
2531            app_version,
2532            true,
2533        )
2534        .await
2535    }
2536
2537    pub(crate) async fn send_v3_request(
2538        input: ZetaPromptInput,
2539        client: Arc<Client>,
2540        llm_token: LlmApiToken,
2541        organization_id: Option<OrganizationId>,
2542        app_version: Version,
2543        trigger: PredictEditsRequestTrigger,
2544    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2545        let url = client
2546            .http_client()
2547            .build_zed_llm_url("/predict_edits/v3", &[])?;
2548
2549        let request = PredictEditsV3Request { input, trigger };
2550
2551        let json_bytes = serde_json::to_vec(&request)?;
2552        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2553
2554        Self::send_api_request(
2555            |builder| {
2556                let req = builder
2557                    .uri(url.as_ref())
2558                    .header("Content-Encoding", "zstd")
2559                    .body(compressed.clone().into());
2560                Ok(req?)
2561            },
2562            client,
2563            llm_token,
2564            organization_id,
2565            app_version,
2566            true,
2567        )
2568        .await
2569    }
2570
2571    async fn send_api_request<Res>(
2572        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2573        client: Arc<Client>,
2574        llm_token: LlmApiToken,
2575        organization_id: Option<OrganizationId>,
2576        app_version: Version,
2577        require_auth: bool,
2578    ) -> Result<(Res, Option<EditPredictionUsage>)>
2579    where
2580        Res: DeserializeOwned,
2581    {
2582        let http_client = client.http_client();
2583
2584        let mut token = if require_auth {
2585            Some(llm_token.acquire(&client, organization_id.clone()).await?)
2586        } else {
2587            llm_token
2588                .acquire(&client, organization_id.clone())
2589                .await
2590                .ok()
2591        };
2592        let mut did_retry = false;
2593
2594        loop {
2595            let request_builder = http_client::Request::builder().method(Method::POST);
2596
2597            let mut request_builder = request_builder
2598                .header("Content-Type", "application/json")
2599                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2600
2601            // Only add Authorization header if we have a token
2602            if let Some(ref token_value) = token {
2603                request_builder =
2604                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2605            }
2606
2607            let request = build(request_builder)?;
2608
2609            let mut response = http_client.send(request).await?;
2610
2611            if let Some(minimum_required_version) = response
2612                .headers()
2613                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2614                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2615            {
2616                anyhow::ensure!(
2617                    app_version >= minimum_required_version,
2618                    ZedUpdateRequiredError {
2619                        minimum_version: minimum_required_version
2620                    }
2621                );
2622            }
2623
2624            if response.status().is_success() {
2625                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2626
2627                let mut body = Vec::new();
2628                response.body_mut().read_to_end(&mut body).await?;
2629                return Ok((serde_json::from_slice(&body)?, usage));
2630            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2631                did_retry = true;
2632                token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
2633            } else {
2634                let mut body = String::new();
2635                response.body_mut().read_to_string(&mut body).await?;
2636                anyhow::bail!(
2637                    "Request failed with status: {:?}\nBody: {}",
2638                    response.status(),
2639                    body
2640                );
2641            }
2642        }
2643    }
2644
2645    pub fn refresh_context(
2646        &mut self,
2647        project: &Entity<Project>,
2648        buffer: &Entity<language::Buffer>,
2649        cursor_position: language::Anchor,
2650        cx: &mut Context<Self>,
2651    ) {
2652        self.get_or_init_project(project, cx)
2653            .context
2654            .update(cx, |store, cx| {
2655                store.refresh(buffer.clone(), cursor_position, cx);
2656            });
2657    }
2658
2659    #[cfg(feature = "cli-support")]
2660    pub fn set_context_for_buffer(
2661        &mut self,
2662        project: &Entity<Project>,
2663        related_files: Vec<RelatedFile>,
2664        cx: &mut Context<Self>,
2665    ) {
2666        self.get_or_init_project(project, cx)
2667            .context
2668            .update(cx, |store, cx| {
2669                store.set_related_files(related_files, cx);
2670            });
2671    }
2672
2673    #[cfg(feature = "cli-support")]
2674    pub fn set_recent_paths_for_project(
2675        &mut self,
2676        project: &Entity<Project>,
2677        paths: impl IntoIterator<Item = project::ProjectPath>,
2678        cx: &mut Context<Self>,
2679    ) {
2680        let project_state = self.get_or_init_project(project, cx);
2681        project_state.recent_paths = paths.into_iter().collect();
2682    }
2683
2684    fn is_file_open_source(
2685        &self,
2686        project: &Entity<Project>,
2687        file: &Arc<dyn File>,
2688        cx: &App,
2689    ) -> bool {
2690        if !file.is_local() || file.is_private() {
2691            return false;
2692        }
2693        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2694            return false;
2695        };
2696        project_state
2697            .license_detection_watchers
2698            .get(&file.worktree_id(cx))
2699            .as_ref()
2700            .is_some_and(|watcher| watcher.is_project_open_source())
2701    }
2702
2703    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2704        self.data_collection_choice.is_enabled(cx)
2705    }
2706
2707    fn load_data_collection_choice() -> DataCollectionChoice {
2708        let choice = KEY_VALUE_STORE
2709            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2710            .log_err()
2711            .flatten();
2712
2713        match choice.as_deref() {
2714            Some("true") => DataCollectionChoice::Enabled,
2715            Some("false") => DataCollectionChoice::Disabled,
2716            Some(_) => {
2717                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2718                DataCollectionChoice::NotAnswered
2719            }
2720            None => DataCollectionChoice::NotAnswered,
2721        }
2722    }
2723
2724    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2725        self.data_collection_choice = self.data_collection_choice.toggle();
2726        let new_choice = self.data_collection_choice;
2727        let is_enabled = new_choice.is_enabled(cx);
2728        db::write_and_log(cx, move || {
2729            KEY_VALUE_STORE.write_kvp(
2730                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2731                is_enabled.to_string(),
2732            )
2733        });
2734    }
2735
2736    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2737        self.shown_predictions.iter()
2738    }
2739
2740    pub fn shown_completions_len(&self) -> usize {
2741        self.shown_predictions.len()
2742    }
2743
2744    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2745        self.rated_predictions.contains(id)
2746    }
2747
2748    pub fn rate_prediction(
2749        &mut self,
2750        prediction: &EditPrediction,
2751        rating: EditPredictionRating,
2752        feedback: String,
2753        cx: &mut Context<Self>,
2754    ) {
2755        let organization = self.user_store.read(cx).current_organization();
2756
2757        self.rated_predictions.insert(prediction.id.clone());
2758
2759        cx.background_spawn({
2760            let client = self.client.clone();
2761            let prediction_id = prediction.id.to_string();
2762            let inputs = serde_json::to_value(&prediction.inputs);
2763            let output = prediction
2764                .edit_preview
2765                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2766            async move {
2767                client
2768                    .cloud_client()
2769                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2770                        organization_id: organization.map(|organization| organization.id.clone()),
2771                        request_id: prediction_id,
2772                        rating: match rating {
2773                            EditPredictionRating::Positive => "positive".to_string(),
2774                            EditPredictionRating::Negative => "negative".to_string(),
2775                        },
2776                        inputs: inputs?,
2777                        output,
2778                        feedback,
2779                    })
2780                    .await?;
2781
2782                anyhow::Ok(())
2783            }
2784        })
2785        .detach_and_log_err(cx);
2786
2787        cx.notify();
2788    }
2789}
2790
2791fn collaborator_edit_overlaps_locality_region(
2792    project_state: &ProjectState,
2793    project: &Entity<Project>,
2794    buffer: &Entity<Buffer>,
2795    snapshot: &BufferSnapshot,
2796    edit_range: &Range<Anchor>,
2797    cx: &App,
2798) -> bool {
2799    let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2800        return false;
2801    };
2802
2803    if active_buffer.entity_id() != buffer.entity_id() {
2804        return false;
2805    }
2806
2807    let locality_point_range = expand_context_syntactically_then_linewise(
2808        snapshot,
2809        (position..position).to_point(snapshot),
2810        COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2811    );
2812    let locality_anchor_range = snapshot.anchor_range_around(locality_point_range);
2813
2814    edit_range.overlaps(&locality_anchor_range, snapshot)
2815}
2816
/// Collapses the trailing run of mergeable events in `events` into a single
/// `BufferChange` event.
///
/// The merged event's diff runs from the oldest merged event's `old_snapshot`
/// to `end_snapshot`, restricted to the union of the merged events'
/// `total_edit_range`s. `events` is left unchanged when:
/// - the newest event's `old_snapshot` has a different `remote_id` than
///   `latest_snapshot`,
/// - `latest_snapshot` has not observed the newest event's `new_snapshot_version`,
/// - fewer than two trailing events are mergeable, or
/// - `compute_diff_between_snapshots_in_range` returns `None`.
///
/// `latest_snapshot` is used to compare anchors. Together with
/// `latest_edit_range` it is passed to `StoredEvent::can_merge` to decide
/// which events are mergeable.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Bail out if the newest event cannot be related to `latest_snapshot`.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Walk from newest to oldest, counting events until one cannot merge with
    // the newer event that follows it.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // Zero or one mergeable event: nothing to collapse.
    if mergeable_count <= 1 {
        return;
    }

    // `peek` does not advance the iterator, so the `all` call further down
    // still visits the oldest event.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    // Grow the range to cover every newer merged event's edits.
    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged change counts as predicted only if every
                    // merged event was predicted.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Replace the merged run with the single combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2899
2900fn merge_anchor_ranges(
2901    left: &Range<Anchor>,
2902    right: &Range<Anchor>,
2903    snapshot: &TextBufferSnapshot,
2904) -> Range<Anchor> {
2905    let start = if left.start.cmp(&right.start, snapshot).is_le() {
2906        left.start
2907    } else {
2908        right.start
2909    };
2910    let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2911        left.end
2912    } else {
2913        right.end
2914    };
2915    start..end
2916}
2917
2918#[derive(Error, Debug)]
2919#[error(
2920    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2921)]
2922pub struct ZedUpdateRequiredError {
2923    minimum_version: Version,
2924}
2925
/// The user's answer to the edit-prediction data-collection prompt.
#[derive(Clone, Copy, Debug)]
pub enum DataCollectionChoice {
    /// The user has not answered yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2932
2933impl DataCollectionChoice {
2934    pub fn is_enabled(self, cx: &App) -> bool {
2935        if cx.is_staff() {
2936            return true;
2937        }
2938        match self {
2939            Self::Enabled => true,
2940            Self::NotAnswered | Self::Disabled => false,
2941        }
2942    }
2943
2944    #[must_use]
2945    pub fn toggle(&self) -> DataCollectionChoice {
2946        match self {
2947            Self::Enabled => Self::Disabled,
2948            Self::Disabled => Self::Enabled,
2949            Self::NotAnswered => Self::Enabled,
2950        }
2951    }
2952}
2953
2954impl From<bool> for DataCollectionChoice {
2955    fn from(value: bool) -> Self {
2956        match value {
2957            true => DataCollectionChoice::Enabled,
2958            false => DataCollectionChoice::Disabled,
2959        }
2960    }
2961}
2962
2963struct ZedPredictUpsell;
2964
2965impl Dismissable for ZedPredictUpsell {
2966    const KEY: &'static str = "dismissed-edit-predict-upsell";
2967
2968    fn dismissed() -> bool {
2969        // To make this backwards compatible with older versions of Zed, we
2970        // check if the user has seen the previous Edit Prediction Onboarding
2971        // before, by checking the data collection choice which was written to
2972        // the database once the user clicked on "Accept and Enable"
2973        if KEY_VALUE_STORE
2974            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2975            .log_err()
2976            .is_some_and(|s| s.is_some())
2977        {
2978            return true;
2979        }
2980
2981        KEY_VALUE_STORE
2982            .read_kvp(Self::KEY)
2983            .log_err()
2984            .is_some_and(|s| s.is_some())
2985    }
2986}
2987
2988pub fn should_show_upsell_modal() -> bool {
2989    !ZedPredictUpsell::dismissed()
2990}
2991
2992pub fn init(cx: &mut App) {
2993    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2994        workspace.register_action(
2995            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2996                ZedPredictModal::toggle(
2997                    workspace,
2998                    workspace.user_store().clone(),
2999                    workspace.client().clone(),
3000                    window,
3001                    cx,
3002                )
3003            },
3004        );
3005
3006        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
3007            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
3008                settings
3009                    .project
3010                    .all_languages
3011                    .edit_predictions
3012                    .get_or_insert_default()
3013                    .provider = Some(EditPredictionProvider::None)
3014            });
3015        });
3016        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
3017            EditPredictionStore::try_global(cx).and_then(|store| {
3018                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
3019            })
3020        }
3021
3022        workspace.register_action(|workspace, _: &SignIn, window, cx| {
3023            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3024                copilot_ui::initiate_sign_in(copilot, window, cx);
3025            }
3026        });
3027        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3028            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3029                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3030            }
3031        });
3032        workspace.register_action(|workspace, _: &SignOut, window, cx| {
3033            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3034                copilot_ui::initiate_sign_out(copilot, window, cx);
3035            }
3036        });
3037    })
3038    .detach();
3039}