edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
   5use cloud_llm_client::predict_edits_v3::{
   6    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   7};
   8use cloud_llm_client::{
   9    EditPredictionRejectReason, EditPredictionRejection,
  10    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  11    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  12};
  13use collections::{HashMap, HashSet};
  14use copilot::{Copilot, Reinstall, SignIn, SignOut};
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::{
  19    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  20    channel::mpsc::{self, UnboundedReceiver},
  21    select_biased,
  22};
  23use gpui::BackgroundExecutor;
  24use gpui::http_client::Url;
  25use gpui::{
  26    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  27    http_client::{self, AsyncBody, Method},
  28    prelude::*,
  29};
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{
  39    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  40};
  41use std::collections::{VecDeque, hash_map};
  42use std::env;
  43use text::{AnchorRangeExt, Edit};
  44use workspace::Workspace;
  45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  46
  47use std::mem;
  48use std::ops::Range;
  49use std::path::Path;
  50use std::rc::Rc;
  51use std::str::FromStr as _;
  52use std::sync::Arc;
  53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  54use thiserror::Error;
  55use util::{RangeExt as _, ResultExt as _};
  56
  57pub mod cursor_excerpt;
  58pub mod example_spec;
  59pub mod fim;
  60mod license_detection;
  61pub mod mercury;
  62pub mod ollama;
  63mod onboarding_modal;
  64pub mod open_ai_response;
  65mod prediction;
  66pub mod sweep_ai;
  67
  68pub mod udiff;
  69
  70mod capture_example;
  71pub mod open_ai_compatible;
  72mod zed_edit_prediction_delegate;
  73pub mod zeta;
  74
  75#[cfg(test)]
  76mod edit_prediction_tests;
  77
  78use crate::example_spec::ExampleSpec;
  79use crate::license_detection::LicenseDetectionWatcher;
  80use crate::mercury::Mercury;
  81use crate::onboarding_modal::ZedPredictModal;
  82pub use crate::prediction::EditPrediction;
  83pub use crate::prediction::EditPredictionId;
  84use crate::prediction::EditPredictionResult;
  85pub use crate::sweep_ai::SweepAi;
  86pub use capture_example::capture_example;
  87pub use language_model::ApiKeyState;
  88pub use telemetry_events::EditPredictionRating;
  89pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  90
// Declares the actions that live under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
 100
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Row-distance threshold used by `StoredEvent::can_merge`: two events farther
/// apart than this are never merged, and events within this distance of the
/// latest edit are kept separate.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// One-second window related to grouping the most recent change.
/// NOTE(review): the use site is outside this section — confirm exact semantics.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// String key for the persisted data-collection choice.
/// NOTE(review): presumably a key in the key-value store — confirm against the use site.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Fifteen-second debounce interval for rejection requests.
/// NOTE(review): presumably consumed by the rejected-predictions worker — confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
/// Name of the "settled" event.
/// NOTE(review): presumably a telemetry event name — confirm against the use site.
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
/// Five-minute time-to-live associated with settled-prediction tracking.
/// NOTE(review): use site not visible here — confirm.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
/// Ten-second quiescence interval associated with settled-prediction tracking.
/// NOTE(review): use site not visible here — confirm.
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
 110
/// Marker type for the `edit_prediction_jumps` feature flag.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    /// Name under which this flag is identified.
    const NAME: &'static str = "edit_prediction_jumps";
}
 116
/// Newtype wrapper that lets the shared [`EditPredictionStore`] entity be stored
/// as an app-wide global. It is read by `EditPredictionStore::try_global` and
/// installed by `EditPredictionStore::global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 121
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier; populated from `ZED_ZETA_MODEL` when the
    /// config is built from the environment.
    pub model_id: Option<String>,
    /// Optional environment name; populated from `ZED_ZETA_ENVIRONMENT` when
    /// the config is built from the environment.
    pub environment: Option<String>,
    /// Prompt format; parsed from `ZED_ZETA_FORMAT` when the config is built
    /// from the environment.
    pub format: ZetaFormat,
}
 131
/// Central state for edit predictions: per-project tracking, the selected
/// model/provider, experiment selection, and the channels feeding the
/// background workers started in `EditPredictionStore::new`.
pub struct EditPredictionStore {
    /// Client used to obtain the HTTP client and to acquire/refresh the LLM token.
    client: Arc<Client>,
    /// Source of the current user and current organization.
    user_store: Entity<UserStore>,
    /// LLM API token shared (via clones) with the background tasks.
    llm_token: LlmApiToken,
    /// Keeps alive the subscription that refreshes `llm_token` whenever the
    /// `RefreshLlmTokenListener` emits an event.
    _llm_token_subscription: Subscription,
    /// Task that waits for a signed-in user and then calls
    /// `refresh_available_experiments`.
    _fetch_experiments_task: Task<()>,
    /// Per-project state keyed by an entity id.
    /// NOTE(review): presumably the project's entity id — confirm at the insertion site.
    projects: HashMap<EntityId, ProjectState>,
    /// Starts out `false`. NOTE(review): presumably set when the server demands
    /// a newer client version — confirm.
    update_required: bool,
    /// Currently selected prediction model; defaults to `EditPredictionModel::Zeta`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration; initialised from the environment.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    /// Experiment explicitly chosen by the user, if any.
    preferred_experiment: Option<String>,
    /// Experiment names fetched by `refresh_available_experiments`.
    available_experiments: Vec<String>,
    /// Sweep provider state.
    pub sweep_ai: SweepAi,
    /// Mercury provider state.
    pub mercury: Mercury,
    /// Data-collection choice loaded at construction time.
    data_collection_choice: DataCollectionChoice,
    /// Sender whose receiving end is drained by the rejected-predictions worker.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Sender whose receiving end is drained by the settled-predictions worker.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    /// Predictions that have been shown; `active_experiment` scans these for a
    /// model version when no preferred experiment is set.
    shown_predictions: VecDeque<EditPrediction>,
    /// Ids of predictions that have been rated.
    /// NOTE(review): inferred from the name and type; insertion site not visible here.
    rated_predictions: HashSet<EditPredictionId>,
    /// Test-only hook; initialised to `None`.
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
 154
/// Item type carried by the `reject_predictions_tx` channel: one rejection plus
/// the organization it should be attributed to, if any.
pub(crate) struct EditPredictionRejectionPayload {
    /// The rejection to report.
    rejection: EditPredictionRejection,
    /// Organization associated with the rejection, when one is known.
    organization_id: Option<OrganizationId>,
}
 159
/// Which backend produces edit predictions. `EditPredictionStore::new` starts
/// with `Zeta`.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    /// Zed's own Zeta model.
    Zeta,
    /// A model driven with the given prompt format.
    /// NOTE(review): "Fim" presumably stands for fill-in-the-middle — confirm.
    Fim { format: EditPredictionPromptFormat },
    /// The Sweep provider.
    Sweep,
    /// The Mercury provider.
    Mercury,
}
 167
/// Inputs bundled together for a single prediction request.
///
/// NOTE(review): the construction site is outside this section; the field notes
/// below are derived from names and types only.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    /// Project the request belongs to.
    project: Entity<Project>,
    /// Buffer the prediction is requested for.
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer`.
    snapshot: BufferSnapshot,
    /// Anchor the prediction is requested at.
    position: Anchor,
    /// Edit-history events supplied to the model.
    events: Vec<Arc<zeta_prompt::Event>>,
    /// Related files supplied as extra context.
    related_files: Vec<RelatedFile>,
    /// Recently used project paths.
    recent_paths: VecDeque<ProjectPath>,
    /// What triggered the request.
    trigger: PredictEditsRequestTrigger,
    /// Point range in which diagnostics are searched.
    diagnostic_search_range: Range<Point>,
    /// Optional channel for emitting [`DebugEvent`]s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    /// Whether data collection is permitted for this request.
    can_collect_data: bool,
    /// Whether the request concerns open-source code.
    is_open_source: bool,
    /// User actions accompanying the request.
    pub user_actions: Vec<UserActionRecord>,
}
 184
/// Events sent over a `debug_tx` channel, marking the start and end of context
/// retrieval and of an edit prediction.
#[derive(Debug)]
pub enum DebugEvent {
    /// Context retrieval has begun.
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    /// Context retrieval has completed.
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    /// An edit prediction request has begun.
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    /// An edit prediction request has completed.
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
 192
/// Payload of [`DebugEvent::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    /// Entity id of the project the retrieval runs for.
    pub project_entity_id: EntityId,
    /// When the retrieval started.
    pub timestamp: Instant,
    /// Prompt used for the search.
    pub search_prompt: String,
}
 199
/// Payload of [`DebugEvent::ContextRetrievalFinished`].
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    /// Entity id of the project the retrieval ran for.
    pub project_entity_id: EntityId,
    /// When the retrieval finished.
    pub timestamp: Instant,
    /// Key/value pairs describing the retrieval.
    pub metadata: Vec<(&'static str, SharedString)>,
}
 206
/// Payload of [`DebugEvent::EditPredictionStarted`].
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    /// Weak handle to the buffer the prediction targets.
    pub buffer: WeakEntity<Buffer>,
    /// Anchor the prediction was requested at.
    pub position: Anchor,
    /// Prompt sent to the model, when one is available.
    pub prompt: Option<String>,
}
 213
/// Payload of [`DebugEvent::EditPredictionFinished`].
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    /// Weak handle to the buffer the prediction targeted.
    pub buffer: WeakEntity<Buffer>,
    /// Anchor the prediction was requested at.
    pub position: Anchor,
    /// Raw model output, when one is available.
    pub model_output: Option<String>,
}
 220
/// Maximum number of [`UserActionRecord`]s kept per project;
/// `ProjectState::record_user_action` evicts the oldest entry once this many
/// are stored.
const USER_ACTION_HISTORY_SIZE: usize = 16;
 222
/// A single recorded user action, stored in `ProjectState::user_actions`.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    /// Kind of action performed.
    pub action_type: UserActionType,
    /// Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    /// Line the action happened on.
    /// NOTE(review): zero- vs one-based is not visible here — confirm at the producer.
    pub line_number: u32,
    /// Offset of the action within the buffer.
    pub offset: usize,
    /// Time of the action, in milliseconds since the Unix epoch.
    pub timestamp_epoch_ms: u64,
}
 231
/// Classification of a [`UserActionRecord`].
///
/// NOTE(review): how an edit is classified is decided outside this section; the
/// notes below follow the variant names.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    /// A single character was inserted.
    InsertChar,
    /// More than a single character was inserted.
    InsertSelection,
    /// A single character was deleted.
    DeleteChar,
    /// More than a single character was deleted.
    DeleteSelection,
    /// The cursor moved without an edit.
    CursorMovement,
}
 240
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The event itself, shared so it can be handed out cheaply.
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot taken before the change described by `event`.
    pub old_snapshot: TextBufferSnapshot,
    /// Range covered by the edit, anchored in the post-change snapshot
    /// (see `LastEvent::finalize`).
    pub edit_range: Range<Anchor>,
}
 248
 249impl StoredEvent {
 250    fn can_merge(
 251        &self,
 252        next_old_event: &&&StoredEvent,
 253        new_snapshot: &TextBufferSnapshot,
 254        last_edit_range: &Range<Anchor>,
 255    ) -> bool {
 256        // Events must be for the same buffer
 257        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 258            return false;
 259        }
 260        if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
 261            return false;
 262        }
 263
 264        let a_is_predicted = matches!(
 265            self.event.as_ref(),
 266            zeta_prompt::Event::BufferChange {
 267                predicted: true,
 268                ..
 269            }
 270        );
 271        let b_is_predicted = matches!(
 272            next_old_event.event.as_ref(),
 273            zeta_prompt::Event::BufferChange {
 274                predicted: true,
 275                ..
 276            }
 277        );
 278
 279        // If events come from the same source (both predicted or both manual) then
 280        // we would have coalesced them already.
 281        if a_is_predicted == b_is_predicted {
 282            return false;
 283        }
 284
 285        let left_range = self.edit_range.to_point(new_snapshot);
 286        let right_range = next_old_event.edit_range.to_point(new_snapshot);
 287        let latest_range = last_edit_range.to_point(&new_snapshot);
 288
 289        // Events near to the latest edit are not merged if their sources differ.
 290        if lines_between_ranges(&left_range, &latest_range)
 291            .min(lines_between_ranges(&right_range, &latest_range))
 292            <= CHANGE_GROUPING_LINE_SPAN
 293        {
 294            return false;
 295        }
 296
 297        // Events that are distant from each other are not merged.
 298        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 299            return false;
 300        }
 301
 302        true
 303    }
 304}
 305
 306fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 307    if left.start > right.end {
 308        return left.start.row - right.end.row;
 309    }
 310    if right.start > left.end {
 311        return right.start.row - left.end.row;
 312    }
 313    0
 314}
 315
/// Everything the store tracks for one project.
struct ProjectState {
    /// Finalized edit-history events, oldest first.
    events: VecDeque<StoredEvent>,
    /// The still-open most recent event; `events()` finalizes it on the fly and
    /// appends it after `events`.
    last_event: Option<LastEvent>,
    /// Recently used project paths.
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers registered with the store, keyed by the buffer's entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently held for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Counter used to assign ids to pending predictions.
    /// NOTE(review): the increment site is outside this section.
    next_pending_prediction_id: usize,
    /// In-flight prediction requests; at most two at a time.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Optional channel for emitting [`DebugEvent`]s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    /// Entity id and time of the last edit-prediction refresh.
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    /// Entity id and time of the last jump-prediction refresh.
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled
    /// (filled by `cancel_pending_prediction`).
    cancelled_predictions: HashSet<usize>,
    /// Store of related excerpts used as prediction context.
    context: Entity<RelatedExcerptStore>,
    /// Per-worktree license watchers; consulted by `LastEvent::finalize` to
    /// decide whether an event is in an open-source repository.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Most recent user actions, capped at `USER_ACTION_HISTORY_SIZE`.
    user_actions: VecDeque<UserActionRecord>,
    /// Subscriptions kept alive for the lifetime of this state.
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot instance for this project, if any.
    copilot: Option<Entity<Copilot>>,
}
 334
 335impl ProjectState {
 336    fn record_user_action(&mut self, action: UserActionRecord) {
 337        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 338            self.user_actions.pop_front();
 339        }
 340        self.user_actions.push_back(action);
 341    }
 342
 343    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 344        self.events
 345            .iter()
 346            .cloned()
 347            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 348                let (one, two) = event.split_by_pause();
 349                let one = one.finalize(&self.license_detection_watchers, cx);
 350                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 351                one.into_iter().chain(two)
 352            }))
 353            .collect()
 354    }
 355
 356    fn cancel_pending_prediction(
 357        &mut self,
 358        pending_prediction: PendingPrediction,
 359        cx: &mut Context<EditPredictionStore>,
 360    ) {
 361        self.cancelled_predictions.insert(pending_prediction.id);
 362
 363        if pending_prediction.drop_on_cancel {
 364            drop(pending_prediction.task);
 365        } else {
 366            cx.spawn(async move |this, cx| {
 367                let Some(prediction_id) = pending_prediction.task.await else {
 368                    return;
 369                };
 370
 371                this.update(cx, |this, cx| {
 372                    this.reject_prediction(
 373                        prediction_id,
 374                        EditPredictionRejectReason::Canceled,
 375                        false,
 376                        None,
 377                        cx,
 378                    );
 379                })
 380                .ok();
 381            })
 382            .detach()
 383        }
 384    }
 385
 386    fn active_buffer(
 387        &self,
 388        project: &Entity<Project>,
 389        cx: &App,
 390    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 391        let project = project.read(cx);
 392        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 393        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 394        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 395        Some((active_buffer, registered_buffer.last_position))
 396    }
 397}
 398
/// The prediction currently held for a project, plus bookkeeping about who
/// asked for it and whether it has been shown.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    /// The prediction itself.
    pub prediction: EditPrediction,
    /// Whether the prediction has been shown.
    pub was_shown: bool,
    /// How the prediction was displayed, when known.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
 406
 407impl CurrentEditPrediction {
 408    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 409        let Some(new_edits) = self
 410            .prediction
 411            .interpolate(&self.prediction.buffer.read(cx))
 412        else {
 413            return false;
 414        };
 415
 416        if self.prediction.buffer != old_prediction.prediction.buffer {
 417            return true;
 418        }
 419
 420        let Some(old_edits) = old_prediction
 421            .prediction
 422            .interpolate(&old_prediction.prediction.buffer.read(cx))
 423        else {
 424            return true;
 425        };
 426
 427        let requested_by_buffer_id = self.requested_by.buffer_id();
 428
 429        // This reduces the occurrence of UI thrash from replacing edits
 430        //
 431        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 432        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 433            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 434            && old_edits.len() == 1
 435            && new_edits.len() == 1
 436        {
 437            let (old_range, old_text) = &old_edits[0];
 438            let (new_range, new_text) = &new_edits[0];
 439            new_range == old_range && new_text.starts_with(old_text.as_ref())
 440        } else {
 441            true
 442        }
 443    }
 444}
 445
 446#[derive(Debug, Clone)]
 447enum PredictionRequestedBy {
 448    DiagnosticsUpdate,
 449    Buffer(EntityId),
 450}
 451
 452impl PredictionRequestedBy {
 453    pub fn buffer_id(&self) -> Option<EntityId> {
 454        match self {
 455            PredictionRequestedBy::DiagnosticsUpdate => None,
 456            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 457        }
 458    }
 459}
 460
/// Line count used when searching for diagnostics.
/// NOTE(review): the use site is outside this section; presumably the number of
/// lines around the cursor considered for a local diagnostic search — confirm.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;
 462
/// How widely diagnostics are searched.
///
/// NOTE(review): the exact meaning of each scope is defined at the use site,
/// which is outside this section.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    /// Restrict the search to a local area.
    Local,
    /// Search without the local restriction.
    Global,
}
 468
/// An in-flight prediction request tracked in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier of this pending request; recorded in
    /// `ProjectState::cancelled_predictions` when the request is cancelled.
    id: usize,
    /// Task that resolves to the id of the produced prediction, or `None` when
    /// no prediction resulted.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 477
 478/// A prediction from the perspective of a buffer.
 479#[derive(Debug)]
 480enum BufferEditPrediction<'a> {
 481    Local { prediction: &'a EditPrediction },
 482    Jump { prediction: &'a EditPrediction },
 483}
 484
 485#[cfg(test)]
 486impl std::ops::Deref for BufferEditPrediction<'_> {
 487    type Target = EditPrediction;
 488
 489    fn deref(&self) -> &Self::Target {
 490        match self {
 491            BufferEditPrediction::Local { prediction } => prediction,
 492            BufferEditPrediction::Jump { prediction } => prediction,
 493        }
 494    }
 495}
 496
/// A prediction tracked per buffer (see `RegisteredBuffer::pending_predictions`)
/// until it is considered settled.
///
/// NOTE(review): the settling logic lives outside this section; field notes are
/// derived from names and types.
#[derive(Clone)]
struct PendingSettledPrediction {
    /// Id of the prediction being tracked.
    request_id: EditPredictionId,
    /// Anchored range of the region that was editable for this prediction.
    editable_anchor_range: Range<Anchor>,
    /// Captured example associated with the prediction, if any.
    example: Option<ExampleSpec>,
    /// When this entry was queued.
    enqueued_at: Instant,
    /// Time of the most recent edit relevant to this entry.
    last_edit_at: Instant,
}
 505
/// State kept for each buffer registered with a project
/// (values of `ProjectState::registered_buffers`).
struct RegisteredBuffer {
    /// File backing the buffer, if it has one.
    file: Option<Arc<dyn File>>,
    /// Most recently recorded text snapshot of the buffer.
    snapshot: TextBufferSnapshot,
    /// Predictions for this buffer that are still waiting to settle.
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last recorded position in the buffer; returned by
    /// `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    /// Subscriptions kept alive while the buffer stays registered.
    _subscriptions: [gpui::Subscription; 2],
}
 513
/// The most recent, still-open change to a buffer, described by the snapshots
/// before and after it. `finalize` turns it into a [`StoredEvent`].
#[derive(Clone)]
struct LastEvent {
    /// Snapshot before the change.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot after the change.
    new_snapshot: TextBufferSnapshot,
    /// File backing the buffer before the change, if any.
    old_file: Option<Arc<dyn File>>,
    /// File backing the buffer after the change, if any.
    new_file: Option<Arc<dyn File>>,
    /// Range of the latest edit, when known. `split_by_pause` resets it to
    /// `None` on both halves.
    edit_range: Option<Range<Anchor>>,
    /// Whether the change came from an applied prediction; copied into the
    /// resulting event's `predicted` flag.
    predicted: bool,
    /// Snapshot at which `split_by_pause` divides this event, if any.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    /// Time of the most recent edit, when known.
    last_edit_time: Option<Instant>,
}
 525
 526impl LastEvent {
 527    pub fn finalize(
 528        &self,
 529        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 530        cx: &App,
 531    ) -> Option<StoredEvent> {
 532        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 533        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 534
 535        let in_open_source_repo =
 536            [self.new_file.as_ref(), self.old_file.as_ref()]
 537                .iter()
 538                .all(|file| {
 539                    file.is_some_and(|file| {
 540                        license_detection_watchers
 541                            .get(&file.worktree_id(cx))
 542                            .is_some_and(|watcher| watcher.is_project_open_source())
 543                    })
 544                });
 545
 546        let (diff, edit_range) =
 547            compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 548
 549        if path == old_path && diff.is_empty() {
 550            None
 551        } else {
 552            Some(StoredEvent {
 553                event: Arc::new(zeta_prompt::Event::BufferChange {
 554                    old_path,
 555                    path,
 556                    diff,
 557                    in_open_source_repo,
 558                    predicted: self.predicted,
 559                }),
 560                edit_range: self.new_snapshot.anchor_before(edit_range.start)
 561                    ..self.new_snapshot.anchor_before(edit_range.end),
 562                old_snapshot: self.old_snapshot.clone(),
 563            })
 564        }
 565    }
 566
 567    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 568        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 569            return (self.clone(), None);
 570        };
 571
 572        let before = LastEvent {
 573            old_snapshot: self.old_snapshot.clone(),
 574            new_snapshot: boundary_snapshot.clone(),
 575            old_file: self.old_file.clone(),
 576            new_file: self.new_file.clone(),
 577            edit_range: None,
 578            predicted: self.predicted,
 579            snapshot_after_last_editing_pause: None,
 580            last_edit_time: self.last_edit_time,
 581        };
 582
 583        let after = LastEvent {
 584            old_snapshot: boundary_snapshot.clone(),
 585            new_snapshot: self.new_snapshot.clone(),
 586            old_file: self.old_file.clone(),
 587            new_file: self.new_file.clone(),
 588            edit_range: None,
 589            predicted: self.predicted,
 590            snapshot_after_last_editing_pause: None,
 591            last_edit_time: self.last_edit_time,
 592        };
 593
 594        (before, Some(after))
 595    }
 596}
 597
 598pub(crate) fn compute_diff_between_snapshots(
 599    old_snapshot: &TextBufferSnapshot,
 600    new_snapshot: &TextBufferSnapshot,
 601) -> Option<(String, Range<Point>)> {
 602    let edits: Vec<Edit<usize>> = new_snapshot
 603        .edits_since::<usize>(&old_snapshot.version)
 604        .collect();
 605
 606    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 607
 608    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 609    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 610    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 611    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 612
 613    const CONTEXT_LINES: u32 = 3;
 614
 615    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 616    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 617    let old_context_end_row =
 618        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 619    let new_context_end_row =
 620        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 621
 622    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 623    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 624    let old_end_line_offset = old_snapshot
 625        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 626    let new_end_line_offset = new_snapshot
 627        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 628    let old_edit_range = old_start_line_offset..old_end_line_offset;
 629    let new_edit_range = new_start_line_offset..new_end_line_offset;
 630
 631    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 632    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 633
 634    let diff = language::unified_diff_with_offsets(
 635        &old_region_text,
 636        &new_region_text,
 637        old_context_start_row,
 638        new_context_start_row,
 639    );
 640
 641    Some((diff, new_start_point..new_end_point))
 642}
 643
 644fn buffer_path_with_id_fallback(
 645    file: Option<&Arc<dyn File>>,
 646    snapshot: &TextBufferSnapshot,
 647    cx: &App,
 648) -> Arc<Path> {
 649    if let Some(file) = file {
 650        file.full_path(cx).into()
 651    } else {
 652        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 653    }
 654}
 655
 656impl EditPredictionStore {
 657    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 658        cx.try_global::<EditPredictionStoreGlobal>()
 659            .map(|global| global.0.clone())
 660    }
 661
 662    pub fn global(
 663        client: &Arc<Client>,
 664        user_store: &Entity<UserStore>,
 665        cx: &mut App,
 666    ) -> Entity<Self> {
 667        cx.try_global::<EditPredictionStoreGlobal>()
 668            .map(|global| global.0.clone())
 669            .unwrap_or_else(|| {
 670                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 671                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 672                ep_store
 673            })
 674    }
 675
 676    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 677        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 678        let data_collection_choice = Self::load_data_collection_choice();
 679
 680        let llm_token = LlmApiToken::default();
 681
 682        let (reject_tx, reject_rx) = mpsc::unbounded();
 683        cx.background_spawn({
 684            let client = client.clone();
 685            let llm_token = llm_token.clone();
 686            let app_version = AppVersion::global(cx);
 687            let background_executor = cx.background_executor().clone();
 688            async move {
 689                Self::handle_rejected_predictions(
 690                    reject_rx,
 691                    client,
 692                    llm_token,
 693                    app_version,
 694                    background_executor,
 695                )
 696                .await
 697            }
 698        })
 699        .detach();
 700
 701        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
 702        cx.spawn(async move |this, cx| {
 703            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
 704        })
 705        .detach();
 706
 707        let mut current_user = user_store.read(cx).watch_current_user();
 708        let fetch_experiments_task = cx.spawn(async move |this, cx| {
 709            while current_user.borrow().is_none() {
 710                current_user.next().await;
 711            }
 712            this.update(cx, |this, cx| {
 713                this.refresh_available_experiments(cx);
 714            })
 715            .log_err();
 716        });
 717
 718        let this = Self {
 719            projects: HashMap::default(),
 720            client,
 721            user_store,
 722            llm_token,
 723            _fetch_experiments_task: fetch_experiments_task,
 724            _llm_token_subscription: cx.subscribe(
 725                &refresh_llm_token_listener,
 726                |this, _listener, _event, cx| {
 727                    let client = this.client.clone();
 728                    let llm_token = this.llm_token.clone();
 729                    let organization_id = this
 730                        .user_store
 731                        .read(cx)
 732                        .current_organization()
 733                        .map(|organization| organization.id.clone());
 734                    cx.spawn(async move |_this, _cx| {
 735                        llm_token.refresh(&client, organization_id).await?;
 736                        anyhow::Ok(())
 737                    })
 738                    .detach_and_log_err(cx);
 739                },
 740            ),
 741            update_required: false,
 742            edit_prediction_model: EditPredictionModel::Zeta,
 743            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 744            preferred_experiment: None,
 745            available_experiments: Vec::new(),
 746            sweep_ai: SweepAi::new(cx),
 747            mercury: Mercury::new(cx),
 748
 749            data_collection_choice,
 750            reject_predictions_tx: reject_tx,
 751            settled_predictions_tx,
 752            rated_predictions: Default::default(),
 753            shown_predictions: Default::default(),
 754            #[cfg(test)]
 755            settled_event_callback: None,
 756        };
 757
 758        this
 759    }
 760
 761    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 762        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 763        let format = ZetaFormat::parse(&version_str).ok()?;
 764        let model_id = env::var("ZED_ZETA_MODEL").ok();
 765        let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
 766        Some(Zeta2RawConfig {
 767            model_id,
 768            environment,
 769            format,
 770        })
 771    }
 772
    /// Selects the model used for subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
 776
    /// Installs a raw Zeta2 configuration, replacing any existing one
    /// (including one read from the environment at construction).
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }
 780
    /// Returns the raw Zeta2 configuration, if one is set.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }
 784
    /// Returns the explicitly preferred experiment name, if one is set.
    pub fn preferred_experiment(&self) -> Option<&str> {
        self.preferred_experiment.as_deref()
    }
 788
    /// Sets the preferred experiment; passing `None` clears it.
    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
        self.preferred_experiment = experiment;
    }
 792
    /// Returns the experiment names most recently stored by
    /// `refresh_available_experiments` (empty until a fetch succeeds).
    pub fn available_experiments(&self) -> &[String] {
        &self.available_experiments
    }
 796
 797    pub fn active_experiment(&self) -> Option<&str> {
 798        self.preferred_experiment.as_deref().or_else(|| {
 799            self.shown_predictions
 800                .iter()
 801                .find_map(|p| p.model_version.as_ref())
 802                .and_then(|model_version| model_version.strip_prefix("zeta2:"))
 803        })
 804    }
 805
 806    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
 807        let client = self.client.clone();
 808        let llm_token = self.llm_token.clone();
 809        let app_version = AppVersion::global(cx);
 810        let organization_id = self
 811            .user_store
 812            .read(cx)
 813            .current_organization()
 814            .map(|organization| organization.id.clone());
 815
 816        cx.spawn(async move |this, cx| {
 817            let experiments = cx
 818                .background_spawn(async move {
 819                    let http_client = client.http_client();
 820                    let token = llm_token.acquire(&client, organization_id).await?;
 821                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
 822                    let request = http_client::Request::builder()
 823                        .method(Method::GET)
 824                        .uri(url.as_ref())
 825                        .header("Authorization", format!("Bearer {}", token))
 826                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 827                        .body(Default::default())?;
 828                    let mut response = http_client.send(request).await?;
 829                    if response.status().is_success() {
 830                        let mut body = Vec::new();
 831                        response.body_mut().read_to_end(&mut body).await?;
 832                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
 833                        Ok(experiments)
 834                    } else {
 835                        let mut body = String::new();
 836                        response.body_mut().read_to_string(&mut body).await?;
 837                        anyhow::bail!(
 838                            "Failed to fetch experiments: {:?}\nBody: {}",
 839                            response.status(),
 840                            body
 841                        );
 842                    }
 843                })
 844                .await?;
 845            this.update(cx, |this, cx| {
 846                this.available_experiments = experiments;
 847                cx.notify();
 848            })?;
 849            anyhow::Ok(())
 850        })
 851        .detach_and_log_err(cx);
 852    }
 853
 854    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 855        use ui::IconName;
 856        match self.edit_prediction_model {
 857            EditPredictionModel::Sweep => {
 858                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 859                    .with_disabled(IconName::SweepAiDisabled)
 860                    .with_up(IconName::SweepAiUp)
 861                    .with_down(IconName::SweepAiDown)
 862                    .with_error(IconName::SweepAiError)
 863            }
 864            EditPredictionModel::Mercury => {
 865                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 866            }
 867            EditPredictionModel::Zeta => {
 868                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 869                    .with_disabled(IconName::ZedPredictDisabled)
 870                    .with_up(IconName::ZedPredictUp)
 871                    .with_down(IconName::ZedPredictDown)
 872                    .with_error(IconName::ZedPredictError)
 873            }
 874            EditPredictionModel::Fim { .. } => {
 875                let settings = &all_language_settings(None, cx).edit_predictions;
 876                match settings.provider {
 877                    EditPredictionProvider::Ollama => {
 878                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 879                    }
 880                    _ => {
 881                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 882                    }
 883                }
 884            }
 885        }
 886    }
 887
 888    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 889        self.sweep_ai.api_token.read(cx).has_key()
 890    }
 891
 892    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 893        self.mercury.api_token.read(cx).has_key()
 894    }
 895
 896    pub fn clear_history(&mut self) {
 897        for project_state in self.projects.values_mut() {
 898            project_state.events.clear();
 899            project_state.last_event.take();
 900        }
 901    }
 902
 903    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 904        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 905            project_state.events.clear();
 906            project_state.last_event.take();
 907        }
 908    }
 909
 910    pub fn edit_history_for_project(
 911        &self,
 912        project: &Entity<Project>,
 913        cx: &App,
 914    ) -> Vec<StoredEvent> {
 915        self.projects
 916            .get(&project.entity_id())
 917            .map(|project_state| project_state.events(cx))
 918            .unwrap_or_default()
 919    }
 920
 921    pub fn context_for_project<'a>(
 922        &'a self,
 923        project: &Entity<Project>,
 924        cx: &'a mut App,
 925    ) -> Vec<RelatedFile> {
 926        self.projects
 927            .get(&project.entity_id())
 928            .map(|project_state| {
 929                project_state.context.update(cx, |context, cx| {
 930                    context
 931                        .related_files_with_buffers(cx)
 932                        .map(|(mut related_file, buffer)| {
 933                            related_file.in_open_source_repo = buffer
 934                                .read(cx)
 935                                .file()
 936                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
 937                            related_file
 938                        })
 939                        .collect()
 940                })
 941            })
 942            .unwrap_or_default()
 943    }
 944
 945    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 946        self.projects
 947            .get(&project.entity_id())
 948            .and_then(|project| project.copilot.clone())
 949    }
 950
 951    pub fn start_copilot_for_project(
 952        &mut self,
 953        project: &Entity<Project>,
 954        cx: &mut Context<Self>,
 955    ) -> Option<Entity<Copilot>> {
 956        if DisableAiSettings::get(None, cx).disable_ai {
 957            return None;
 958        }
 959        let state = self.get_or_init_project(project, cx);
 960
 961        if state.copilot.is_some() {
 962            return state.copilot.clone();
 963        }
 964        let _project = project.clone();
 965        let project = project.read(cx);
 966
 967        let node = project.node_runtime().cloned();
 968        if let Some(node) = node {
 969            let next_id = project.languages().next_language_server_id();
 970            let fs = project.fs().clone();
 971
 972            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 973            state.copilot = Some(copilot.clone());
 974            Some(copilot)
 975        } else {
 976            None
 977        }
 978    }
 979
 980    pub fn context_for_project_with_buffers<'a>(
 981        &'a self,
 982        project: &Entity<Project>,
 983        cx: &'a mut App,
 984    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 985        self.projects
 986            .get(&project.entity_id())
 987            .map(|project| {
 988                project.context.update(cx, |context, cx| {
 989                    context.related_files_with_buffers(cx).collect()
 990                })
 991            })
 992            .unwrap_or_default()
 993    }
 994
 995    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 996        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
 997            self.user_store.read(cx).edit_prediction_usage()
 998        } else {
 999            None
1000        }
1001    }
1002
1003    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1004        self.get_or_init_project(project, cx);
1005    }
1006
1007    pub fn register_buffer(
1008        &mut self,
1009        buffer: &Entity<Buffer>,
1010        project: &Entity<Project>,
1011        cx: &mut Context<Self>,
1012    ) {
1013        let project_state = self.get_or_init_project(project, cx);
1014        Self::register_buffer_impl(project_state, buffer, project, cx);
1015    }
1016
1017    fn get_or_init_project(
1018        &mut self,
1019        project: &Entity<Project>,
1020        cx: &mut Context<Self>,
1021    ) -> &mut ProjectState {
1022        let entity_id = project.entity_id();
1023        self.projects
1024            .entry(entity_id)
1025            .or_insert_with(|| ProjectState {
1026                context: {
1027                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
1028                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
1029                        this.handle_excerpt_store_event(entity_id, event);
1030                    })
1031                    .detach();
1032                    related_excerpt_store
1033                },
1034                events: VecDeque::new(),
1035                last_event: None,
1036                recent_paths: VecDeque::new(),
1037                debug_tx: None,
1038                registered_buffers: HashMap::default(),
1039                current_prediction: None,
1040                cancelled_predictions: HashSet::default(),
1041                pending_predictions: ArrayVec::new(),
1042                next_pending_prediction_id: 0,
1043                last_edit_prediction_refresh: None,
1044                last_jump_prediction_refresh: None,
1045                license_detection_watchers: HashMap::default(),
1046                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
1047                _subscriptions: [
1048                    cx.subscribe(&project, Self::handle_project_event),
1049                    cx.observe_release(&project, move |this, _, cx| {
1050                        this.projects.remove(&entity_id);
1051                        cx.notify();
1052                    }),
1053                ],
1054                copilot: None,
1055            })
1056    }
1057
1058    pub fn remove_project(&mut self, project: &Entity<Project>) {
1059        self.projects.remove(&project.entity_id());
1060    }
1061
1062    fn handle_excerpt_store_event(
1063        &mut self,
1064        project_entity_id: EntityId,
1065        event: &RelatedExcerptStoreEvent,
1066    ) {
1067        if let Some(project_state) = self.projects.get(&project_entity_id) {
1068            if let Some(debug_tx) = project_state.debug_tx.clone() {
1069                match event {
1070                    RelatedExcerptStoreEvent::StartedRefresh => {
1071                        debug_tx
1072                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
1073                                ContextRetrievalStartedDebugEvent {
1074                                    project_entity_id: project_entity_id,
1075                                    timestamp: Instant::now(),
1076                                    search_prompt: String::new(),
1077                                },
1078                            ))
1079                            .ok();
1080                    }
1081                    RelatedExcerptStoreEvent::FinishedRefresh {
1082                        cache_hit_count,
1083                        cache_miss_count,
1084                        mean_definition_latency,
1085                        max_definition_latency,
1086                    } => {
1087                        debug_tx
1088                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
1089                                ContextRetrievalFinishedDebugEvent {
1090                                    project_entity_id: project_entity_id,
1091                                    timestamp: Instant::now(),
1092                                    metadata: vec![
1093                                        (
1094                                            "Cache Hits",
1095                                            format!(
1096                                                "{}/{}",
1097                                                cache_hit_count,
1098                                                cache_hit_count + cache_miss_count
1099                                            )
1100                                            .into(),
1101                                        ),
1102                                        (
1103                                            "Max LSP Time",
1104                                            format!("{} ms", max_definition_latency.as_millis())
1105                                                .into(),
1106                                        ),
1107                                        (
1108                                            "Mean LSP Time",
1109                                            format!("{} ms", mean_definition_latency.as_millis())
1110                                                .into(),
1111                                        ),
1112                                    ],
1113                                },
1114                            ))
1115                            .ok();
1116                    }
1117                }
1118            }
1119        }
1120    }
1121
1122    pub fn debug_info(
1123        &mut self,
1124        project: &Entity<Project>,
1125        cx: &mut Context<Self>,
1126    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1127        let project_state = self.get_or_init_project(project, cx);
1128        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1129        project_state.debug_tx = Some(debug_watch_tx);
1130        debug_watch_rx
1131    }
1132
1133    fn handle_project_event(
1134        &mut self,
1135        project: Entity<Project>,
1136        event: &project::Event,
1137        cx: &mut Context<Self>,
1138    ) {
1139        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1140            return;
1141        }
1142        // TODO [zeta2] init with recent paths
1143        match event {
1144            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1145                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1146                    return;
1147                };
1148                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1149                if let Some(path) = path {
1150                    if let Some(ix) = project_state
1151                        .recent_paths
1152                        .iter()
1153                        .position(|probe| probe == &path)
1154                    {
1155                        project_state.recent_paths.remove(ix);
1156                    }
1157                    project_state.recent_paths.push_front(path);
1158                }
1159            }
1160            project::Event::DiagnosticsUpdated { .. } => {
1161                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1162                    self.refresh_prediction_from_diagnostics(
1163                        project,
1164                        DiagnosticSearchScope::Global,
1165                        cx,
1166                    );
1167                }
1168            }
1169            _ => (),
1170        }
1171    }
1172
1173    fn register_buffer_impl<'a>(
1174        project_state: &'a mut ProjectState,
1175        buffer: &Entity<Buffer>,
1176        project: &Entity<Project>,
1177        cx: &mut Context<Self>,
1178    ) -> &'a mut RegisteredBuffer {
1179        let buffer_id = buffer.entity_id();
1180
1181        if let Some(file) = buffer.read(cx).file() {
1182            let worktree_id = file.worktree_id(cx);
1183            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
1184                project_state
1185                    .license_detection_watchers
1186                    .entry(worktree_id)
1187                    .or_insert_with(|| {
1188                        let project_entity_id = project.entity_id();
1189                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
1190                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
1191                            else {
1192                                return;
1193                            };
1194                            project_state
1195                                .license_detection_watchers
1196                                .remove(&worktree_id);
1197                        })
1198                        .detach();
1199                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
1200                    });
1201            }
1202        }
1203
1204        match project_state.registered_buffers.entry(buffer_id) {
1205            hash_map::Entry::Occupied(entry) => entry.into_mut(),
1206            hash_map::Entry::Vacant(entry) => {
1207                let buf = buffer.read(cx);
1208                let snapshot = buf.text_snapshot();
1209                let file = buf.file().cloned();
1210                let project_entity_id = project.entity_id();
1211                entry.insert(RegisteredBuffer {
1212                    snapshot,
1213                    file,
1214                    last_position: None,
1215                    pending_predictions: Vec::new(),
1216                    _subscriptions: [
1217                        cx.subscribe(buffer, {
1218                            let project = project.downgrade();
1219                            move |this, buffer, event, cx| {
1220                                if let language::BufferEvent::Edited = event
1221                                    && let Some(project) = project.upgrade()
1222                                {
1223                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
1224                                }
1225                            }
1226                        }),
1227                        cx.observe_release(buffer, move |this, _buffer, _cx| {
1228                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
1229                            else {
1230                                return;
1231                            };
1232                            project_state.registered_buffers.remove(&buffer_id);
1233                        }),
1234                    ],
1235                })
1236            }
1237        }
1238    }
1239
1240    fn report_changes_for_buffer(
1241        &mut self,
1242        buffer: &Entity<Buffer>,
1243        project: &Entity<Project>,
1244        is_predicted: bool,
1245        cx: &mut Context<Self>,
1246    ) {
1247        let project_state = self.get_or_init_project(project, cx);
1248        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
1249
1250        let buf = buffer.read(cx);
1251        let new_file = buf.file().cloned();
1252        let new_snapshot = buf.text_snapshot();
1253        if new_snapshot.version == registered_buffer.snapshot.version {
1254            return;
1255        }
1256
1257        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
1258        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1259        let mut num_edits = 0usize;
1260        let mut total_deleted = 0usize;
1261        let mut total_inserted = 0usize;
1262        let mut edit_range: Option<Range<Anchor>> = None;
1263        let mut last_offset: Option<usize> = None;
1264        let now = cx.background_executor().now();
1265
1266        for (edit, anchor_range) in
1267            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
1268        {
1269            num_edits += 1;
1270            total_deleted += edit.old.len();
1271            total_inserted += edit.new.len();
1272            edit_range = Some(match edit_range {
1273                None => anchor_range,
1274                Some(acc) => acc.start..anchor_range.end,
1275            });
1276            last_offset = Some(edit.new.end);
1277        }
1278
1279        let Some(edit_range) = edit_range else {
1280            return;
1281        };
1282
1283        for pending_prediction in &mut registered_buffer.pending_predictions {
1284            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
1285                pending_prediction.last_edit_at = now;
1286            }
1287        }
1288
1289        let action_type = match (total_deleted, total_inserted, num_edits) {
1290            (0, ins, n) if ins == n => UserActionType::InsertChar,
1291            (0, _, _) => UserActionType::InsertSelection,
1292            (del, 0, n) if del == n => UserActionType::DeleteChar,
1293            (_, 0, _) => UserActionType::DeleteSelection,
1294            (_, ins, n) if ins == n => UserActionType::InsertChar,
1295            (_, _, _) => UserActionType::InsertSelection,
1296        };
1297
1298        if let Some(offset) = last_offset {
1299            let point = new_snapshot.offset_to_point(offset);
1300            let timestamp_epoch_ms = SystemTime::now()
1301                .duration_since(UNIX_EPOCH)
1302                .map(|d| d.as_millis() as u64)
1303                .unwrap_or(0);
1304            project_state.record_user_action(UserActionRecord {
1305                action_type,
1306                buffer_id: buffer.entity_id(),
1307                line_number: point.row,
1308                offset,
1309                timestamp_epoch_ms,
1310            });
1311        }
1312
1313        let events = &mut project_state.events;
1314
1315        if let Some(last_event) = project_state.last_event.as_mut() {
1316            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
1317                == last_event.new_snapshot.remote_id()
1318                && old_snapshot.version == last_event.new_snapshot.version;
1319
1320            let prediction_source_changed = is_predicted != last_event.predicted;
1321
1322            let should_coalesce = is_next_snapshot_of_same_buffer
1323                && !prediction_source_changed
1324                && last_event
1325                    .edit_range
1326                    .as_ref()
1327                    .is_some_and(|last_edit_range| {
1328                        lines_between_ranges(
1329                            &edit_range.to_point(&new_snapshot),
1330                            &last_edit_range.to_point(&new_snapshot),
1331                        ) <= CHANGE_GROUPING_LINE_SPAN
1332                    });
1333
1334            if should_coalesce {
1335                let pause_elapsed = last_event
1336                    .last_edit_time
1337                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
1338                    .unwrap_or(false);
1339                if pause_elapsed {
1340                    last_event.snapshot_after_last_editing_pause =
1341                        Some(last_event.new_snapshot.clone());
1342                }
1343
1344                last_event.edit_range = Some(edit_range);
1345                last_event.new_snapshot = new_snapshot;
1346                last_event.last_edit_time = Some(now);
1347                return;
1348            }
1349        }
1350
1351        if let Some(event) = project_state.last_event.take() {
1352            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
1353                if events.len() + 1 >= EVENT_COUNT_MAX {
1354                    events.pop_front();
1355                }
1356                events.push_back(event);
1357            }
1358        }
1359
1360        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);
1361
1362        project_state.last_event = Some(LastEvent {
1363            old_file,
1364            new_file,
1365            old_snapshot,
1366            new_snapshot,
1367            edit_range: Some(edit_range),
1368            predicted: is_predicted,
1369            snapshot_after_last_editing_pause: None,
1370            last_edit_time: Some(now),
1371        });
1372    }
1373
1374    fn prediction_at(
1375        &mut self,
1376        buffer: &Entity<Buffer>,
1377        position: Option<language::Anchor>,
1378        project: &Entity<Project>,
1379        cx: &App,
1380    ) -> Option<BufferEditPrediction<'_>> {
1381        let project_state = self.projects.get_mut(&project.entity_id())?;
1382        if let Some(position) = position
1383            && let Some(buffer) = project_state
1384                .registered_buffers
1385                .get_mut(&buffer.entity_id())
1386        {
1387            buffer.last_position = Some(position);
1388        }
1389
1390        let CurrentEditPrediction {
1391            requested_by,
1392            prediction,
1393            ..
1394        } = project_state.current_prediction.as_ref()?;
1395
1396        if prediction.targets_buffer(buffer.read(cx)) {
1397            Some(BufferEditPrediction::Local { prediction })
1398        } else {
1399            let show_jump = match requested_by {
1400                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1401                    requested_by_buffer_id == &buffer.entity_id()
1402                }
1403                PredictionRequestedBy::DiagnosticsUpdate => true,
1404            };
1405
1406            if show_jump {
1407                Some(BufferEditPrediction::Jump { prediction })
1408            } else {
1409                None
1410            }
1411        }
1412    }
1413
1414    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1415        let Some(current_prediction) = self
1416            .projects
1417            .get_mut(&project.entity_id())
1418            .and_then(|project_state| project_state.current_prediction.take())
1419        else {
1420            return;
1421        };
1422
1423        self.report_changes_for_buffer(&current_prediction.prediction.buffer, project, true, cx);
1424
1425        // can't hold &mut project_state ref across report_changes_for_buffer_call
1426        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1427            return;
1428        };
1429
1430        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1431            project_state.cancel_pending_prediction(pending_prediction, cx);
1432        }
1433
1434        match self.edit_prediction_model {
1435            EditPredictionModel::Sweep => {
1436                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1437            }
1438            EditPredictionModel::Mercury => {
1439                mercury::edit_prediction_accepted(
1440                    current_prediction.prediction.id,
1441                    self.client.http_client(),
1442                    cx,
1443                );
1444            }
1445            EditPredictionModel::Zeta => {
1446                let is_cloud = !matches!(
1447                    all_language_settings(None, cx).edit_predictions.provider,
1448                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1449                );
1450                if is_cloud {
1451                    zeta::edit_prediction_accepted(self, current_prediction, cx)
1452                }
1453            }
1454            EditPredictionModel::Fim { .. } => {}
1455        }
1456    }
1457
1458    async fn handle_rejected_predictions(
1459        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
1460        client: Arc<Client>,
1461        llm_token: LlmApiToken,
1462        app_version: Version,
1463        background_executor: BackgroundExecutor,
1464    ) {
1465        let mut rx = std::pin::pin!(rx.peekable());
1466        let mut batched = Vec::new();
1467
1468        while let Some(EditPredictionRejectionPayload {
1469            rejection,
1470            organization_id,
1471        }) = rx.next().await
1472        {
1473            batched.push(rejection);
1474
1475            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1476                select_biased! {
1477                    next = rx.as_mut().peek().fuse() => {
1478                        if next.is_some() {
1479                            continue;
1480                        }
1481                    }
1482                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1483                }
1484            }
1485
1486            let url = client
1487                .http_client()
1488                .build_zed_llm_url("/predict_edits/reject", &[])
1489                .unwrap();
1490
1491            let flush_count = batched
1492                .len()
1493                // in case items have accumulated after failure
1494                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1495            let start = batched.len() - flush_count;
1496
1497            let body = RejectEditPredictionsBodyRef {
1498                rejections: &batched[start..],
1499            };
1500
1501            let result = Self::send_api_request::<()>(
1502                |builder| {
1503                    let req = builder
1504                        .uri(url.as_ref())
1505                        .body(serde_json::to_string(&body)?.into());
1506                    anyhow::Ok(req?)
1507                },
1508                client.clone(),
1509                llm_token.clone(),
1510                organization_id,
1511                app_version.clone(),
1512                true,
1513            )
1514            .await;
1515
1516            if result.log_err().is_some() {
1517                batched.drain(start..);
1518            }
1519        }
1520    }
1521
1522    async fn run_settled_predictions_worker(
1523        this: WeakEntity<Self>,
1524        mut rx: UnboundedReceiver<Instant>,
1525        cx: &mut AsyncApp,
1526    ) {
1527        let mut next_wake_time: Option<Instant> = None;
1528        loop {
1529            let now = cx.background_executor().now();
1530            if let Some(wake_time) = next_wake_time.take() {
1531                cx.background_executor()
1532                    .timer(wake_time.duration_since(now))
1533                    .await;
1534            } else {
1535                let Some(new_enqueue_time) = rx.next().await else {
1536                    break;
1537                };
1538                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1539                while rx.next().now_or_never().flatten().is_some() {}
1540                continue;
1541            }
1542
1543            let Some(this) = this.upgrade() else {
1544                break;
1545            };
1546
1547            let now = cx.background_executor().now();
1548
1549            let mut oldest_edited_at = None;
1550
1551            this.update(cx, |this, _| {
1552                for (_, project_state) in this.projects.iter_mut() {
1553                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
1554                        registered_buffer
1555                            .pending_predictions
1556                            .retain_mut(|pending_prediction| {
1557                                let age =
1558                                    now.saturating_duration_since(pending_prediction.enqueued_at);
1559                                if age >= EDIT_PREDICTION_SETTLED_TTL {
1560                                    return false;
1561                                }
1562
1563                                let quiet_for =
1564                                    now.saturating_duration_since(pending_prediction.last_edit_at);
1565                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
1566                                    let settled_editable_region = registered_buffer
1567                                        .snapshot
1568                                        .text_for_range(
1569                                            pending_prediction.editable_anchor_range.clone(),
1570                                        )
1571                                        .collect::<String>();
1572
1573                                    #[cfg(test)]
1574                                    if let Some(callback) = &this.settled_event_callback {
1575                                        callback(
1576                                            pending_prediction.request_id.clone(),
1577                                            settled_editable_region.clone(),
1578                                        );
1579                                    }
1580
1581                                    telemetry::event!(
1582                                        EDIT_PREDICTION_SETTLED_EVENT,
1583                                        request_id = pending_prediction.request_id.0.clone(),
1584                                        settled_editable_region,
1585                                        example = pending_prediction.example.take(),
1586                                    );
1587
1588                                    return false;
1589                                }
1590
1591                                if oldest_edited_at
1592                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
1593                                {
1594                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
1595                                }
1596
1597                                true
1598                            });
1599                    }
1600                }
1601            });
1602
1603            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1604        }
1605    }
1606
1607    pub(crate) fn enqueue_settled_prediction(
1608        &mut self,
1609        request_id: EditPredictionId,
1610        project: &Entity<Project>,
1611        edited_buffer: &Entity<Buffer>,
1612        edited_buffer_snapshot: &BufferSnapshot,
1613        editable_offset_range: Range<usize>,
1614        example: Option<ExampleSpec>,
1615        cx: &mut Context<Self>,
1616    ) {
1617        let this = &mut *self;
1618        let project_state = this.get_or_init_project(project, cx);
1619        if let Some(buffer) = project_state
1620            .registered_buffers
1621            .get_mut(&edited_buffer.entity_id())
1622        {
1623            let now = cx.background_executor().now();
1624            buffer.pending_predictions.push(PendingSettledPrediction {
1625                request_id: request_id,
1626                editable_anchor_range: edited_buffer_snapshot
1627                    .anchor_range_around(editable_offset_range),
1628                example,
1629                enqueued_at: now,
1630                last_edit_at: now,
1631            });
1632            this.settled_predictions_tx.unbounded_send(now).ok();
1633        }
1634    }
1635
1636    fn reject_current_prediction(
1637        &mut self,
1638        reason: EditPredictionRejectReason,
1639        project: &Entity<Project>,
1640        cx: &App,
1641    ) {
1642        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1643            project_state.pending_predictions.clear();
1644            if let Some(prediction) = project_state.current_prediction.take() {
1645                let model_version = prediction.prediction.model_version.clone();
1646                self.reject_prediction(
1647                    prediction.prediction.id,
1648                    reason,
1649                    prediction.was_shown,
1650                    model_version,
1651                    cx,
1652                );
1653            }
1654        };
1655    }
1656
1657    fn did_show_current_prediction(
1658        &mut self,
1659        project: &Entity<Project>,
1660        display_type: edit_prediction_types::SuggestionDisplayType,
1661        cx: &mut Context<Self>,
1662    ) {
1663        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1664            return;
1665        };
1666
1667        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1668            return;
1669        };
1670
1671        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1672        let previous_shown_with = current_prediction.shown_with;
1673
1674        if previous_shown_with.is_none() || !is_jump {
1675            current_prediction.shown_with = Some(display_type);
1676        }
1677
1678        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1679
1680        if is_first_non_jump_show {
1681            current_prediction.was_shown = true;
1682        }
1683
1684        let display_type_changed = previous_shown_with != Some(display_type);
1685
1686        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1687            sweep_ai::edit_prediction_shown(
1688                &self.sweep_ai,
1689                self.client.clone(),
1690                &current_prediction.prediction,
1691                display_type,
1692                cx,
1693            );
1694        }
1695
1696        if is_first_non_jump_show {
1697            self.shown_predictions
1698                .push_front(current_prediction.prediction.clone());
1699            if self.shown_predictions.len() > 50 {
1700                let completion = self.shown_predictions.pop_back().unwrap();
1701                self.rated_predictions.remove(&completion.id);
1702            }
1703        }
1704    }
1705
1706    fn reject_prediction(
1707        &mut self,
1708        prediction_id: EditPredictionId,
1709        reason: EditPredictionRejectReason,
1710        was_shown: bool,
1711        model_version: Option<String>,
1712        cx: &App,
1713    ) {
1714        match self.edit_prediction_model {
1715            EditPredictionModel::Zeta => {
1716                let is_cloud = !matches!(
1717                    all_language_settings(None, cx).edit_predictions.provider,
1718                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1719                );
1720
1721                if is_cloud {
1722                    let organization_id = self
1723                        .user_store
1724                        .read(cx)
1725                        .current_organization()
1726                        .map(|organization| organization.id.clone());
1727
1728                    self.reject_predictions_tx
1729                        .unbounded_send(EditPredictionRejectionPayload {
1730                            rejection: EditPredictionRejection {
1731                                request_id: prediction_id.to_string(),
1732                                reason,
1733                                was_shown,
1734                                model_version,
1735                            },
1736                            organization_id,
1737                        })
1738                        .log_err();
1739                }
1740            }
1741            EditPredictionModel::Mercury => {
1742                mercury::edit_prediction_rejected(
1743                    prediction_id,
1744                    was_shown,
1745                    reason,
1746                    self.client.http_client(),
1747                    cx,
1748                );
1749            }
1750            EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1751        }
1752    }
1753
1754    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1755        self.projects
1756            .get(&project.entity_id())
1757            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1758    }
1759
1760    pub fn refresh_prediction_from_buffer(
1761        &mut self,
1762        project: Entity<Project>,
1763        buffer: Entity<Buffer>,
1764        position: language::Anchor,
1765        cx: &mut Context<Self>,
1766    ) {
1767        self.queue_prediction_refresh(
1768            project.clone(),
1769            PredictEditsRequestTrigger::Other,
1770            buffer.entity_id(),
1771            cx,
1772            move |this, cx| {
1773                let Some(request_task) = this
1774                    .update(cx, |this, cx| {
1775                        this.request_prediction(
1776                            &project,
1777                            &buffer,
1778                            position,
1779                            PredictEditsRequestTrigger::Other,
1780                            cx,
1781                        )
1782                    })
1783                    .log_err()
1784                else {
1785                    return Task::ready(anyhow::Ok(None));
1786                };
1787
1788                cx.spawn(async move |_cx| {
1789                    request_task.await.map(|prediction_result| {
1790                        prediction_result.map(|prediction_result| {
1791                            (
1792                                prediction_result,
1793                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1794                            )
1795                        })
1796                    })
1797                })
1798            },
1799        )
1800    }
1801
1802    pub fn refresh_prediction_from_diagnostics(
1803        &mut self,
1804        project: Entity<Project>,
1805        scope: DiagnosticSearchScope,
1806        cx: &mut Context<Self>,
1807    ) {
1808        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1809            return;
1810        }
1811
1812        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1813            return;
1814        };
1815
1816        // Prefer predictions from buffer
1817        if project_state.current_prediction.is_some() {
1818            log::debug!(
1819                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
1820            );
1821            return;
1822        }
1823
1824        self.queue_prediction_refresh(
1825            project.clone(),
1826            PredictEditsRequestTrigger::Diagnostics,
1827            project.entity_id(),
1828            cx,
1829            move |this, cx| {
1830                let Some((active_buffer, snapshot, cursor_point)) = this
1831                    .read_with(cx, |this, cx| {
1832                        let project_state = this.projects.get(&project.entity_id())?;
1833                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
1834                        let snapshot = buffer.read(cx).snapshot();
1835
1836                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
1837                            return None;
1838                        }
1839
1840                        let cursor_point = position
1841                            .map(|pos| pos.to_point(&snapshot))
1842                            .unwrap_or_default();
1843
1844                        Some((buffer, snapshot, cursor_point))
1845                    })
1846                    .log_err()
1847                    .flatten()
1848                else {
1849                    return Task::ready(anyhow::Ok(None));
1850                };
1851
1852                cx.spawn(async move |cx| {
1853                    let diagnostic_search_range = match scope {
1854                        DiagnosticSearchScope::Local => {
1855                            let diagnostic_search_start =
1856                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1857                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1858                            Point::new(diagnostic_search_start, 0)
1859                                ..Point::new(diagnostic_search_end, 0)
1860                        }
1861                        DiagnosticSearchScope::Global => Default::default(),
1862                    };
1863
1864                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1865                        active_buffer,
1866                        &snapshot,
1867                        diagnostic_search_range,
1868                        cursor_point,
1869                        &project,
1870                        cx,
1871                    )
1872                    .await?
1873                    else {
1874                        return anyhow::Ok(None);
1875                    };
1876
1877                    let Some(prediction_result) = this
1878                        .update(cx, |this, cx| {
1879                            this.request_prediction(
1880                                &project,
1881                                &jump_buffer,
1882                                jump_position,
1883                                PredictEditsRequestTrigger::Diagnostics,
1884                                cx,
1885                            )
1886                        })?
1887                        .await?
1888                    else {
1889                        return anyhow::Ok(None);
1890                    };
1891
1892                    this.update(cx, |this, cx| {
1893                        Some((
1894                            if this
1895                                .get_or_init_project(&project, cx)
1896                                .current_prediction
1897                                .is_none()
1898                            {
1899                                prediction_result
1900                            } else {
1901                                EditPredictionResult {
1902                                    id: prediction_result.id,
1903                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1904                                }
1905                            },
1906                            PredictionRequestedBy::DiagnosticsUpdate,
1907                        ))
1908                    })
1909                })
1910            },
1911        );
1912    }
1913
1914    fn predictions_enabled_at(
1915        snapshot: &BufferSnapshot,
1916        position: Option<language::Anchor>,
1917        cx: &App,
1918    ) -> bool {
1919        let file = snapshot.file();
1920        let all_settings = all_language_settings(file, cx);
1921        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1922            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1923        {
1924            return false;
1925        }
1926
1927        if let Some(last_position) = position {
1928            let settings = snapshot.settings_at(last_position, cx);
1929
1930            if !settings.edit_predictions_disabled_in.is_empty()
1931                && let Some(scope) = snapshot.language_scope_at(last_position)
1932                && let Some(scope_name) = scope.override_name()
1933                && settings
1934                    .edit_predictions_disabled_in
1935                    .iter()
1936                    .any(|s| s == scope_name)
1937            {
1938                return false;
1939            }
1940        }
1941
1942        true
1943    }
1944
    /// Throttle window applied by `queue_prediction_refresh`: a refresh for the
    /// same throttle entity waits until this long after the previous refresh
    /// of the same kind before it proceeds.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1946}
1947
1948fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1949    match provider {
1950        EditPredictionProvider::Zed
1951        | EditPredictionProvider::Sweep
1952        | EditPredictionProvider::Mercury
1953        | EditPredictionProvider::Ollama
1954        | EditPredictionProvider::OpenAiCompatibleApi
1955        | EditPredictionProvider::Experimental(_) => true,
1956        EditPredictionProvider::None
1957        | EditPredictionProvider::Copilot
1958        | EditPredictionProvider::Codestral => false,
1959    }
1960}
1961
1962impl EditPredictionStore {
    /// Enqueues a throttled, cancellable prediction refresh for `project`.
    ///
    /// `do_refresh` performs the actual request. It is only invoked if the
    /// queued task was not cancelled while waiting out the throttle and no
    /// other refresh of the same kind proceeded since this one was enqueued.
    ///
    /// * `request_trigger` selects which throttle slot is used: `Diagnostics`
    ///   refreshes use `last_jump_prediction_refresh`, everything else uses
    ///   `last_edit_prediction_refresh`.
    /// * `throttle_entity` is the throttle key: a wait is only imposed when the
    ///   previous refresh in the slot was for the same entity.
    ///
    /// When the refresh completes, its result either becomes the project's
    /// current prediction, replaces it, or is reported as rejected; the task is
    /// then removed from the pending list and any pending predictions that were
    /// enqueued before it are cancelled.
    ///
    /// Logs an error and returns without queueing if the configured provider is
    /// `None`, `Copilot` or `Codestral`.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        request_trigger: PredictEditsRequestTrigger,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        // Returns the throttle slot (last entity + timestamp) that applies to
        // the given trigger.
        fn select_throttle(
            project_state: &mut ProjectState,
            request_trigger: PredictEditsRequestTrigger,
        ) -> &mut Option<(EntityId, Instant)> {
            match request_trigger {
                PredictEditsRequestTrigger::Diagnostics => {
                    &mut project_state.last_jump_prediction_refresh
                }
                _ => &mut project_state.last_edit_prediction_refresh,
            }
        }

        // Per-provider policy: whether acceptance tracking is needed, and how
        // many requests may be pending at once.
        let (needs_acceptance_tracking, max_pending_predictions) =
            match all_language_settings(None, cx).edit_predictions.provider {
                EditPredictionProvider::Zed
                | EditPredictionProvider::Sweep
                | EditPredictionProvider::Mercury
                | EditPredictionProvider::Experimental(_) => (true, 2),
                EditPredictionProvider::Ollama => (false, 1),
                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
                EditPredictionProvider::None
                | EditPredictionProvider::Copilot
                | EditPredictionProvider::Codestral => {
                    log::error!("queue_prediction_refresh called with non-store provider");
                    return;
                }
            };

        let drop_on_cancel = !needs_acceptance_tracking;
        let throttle_timeout = Self::THROTTLE_TIMEOUT;
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        // Snapshot of the throttle slot at enqueue time; compared again after
        // the throttle wait to detect that another refresh went ahead.
        let throttle_at_enqueue = *select_throttle(project_state, request_trigger);

        let task = cx.spawn(async move |this, cx| {
            // Remaining throttle time, if the previous refresh in this slot was
            // for the same entity and the timeout has not yet elapsed.
            // NOTE(review): this reads the wall clock via `Instant::now()`,
            // while `enqueue_settled_prediction` uses
            // `cx.background_executor().now()` — confirm this is intended.
            let throttle_wait = this
                .update(cx, |this, cx| {
                    let project_state = this.get_or_init_project(&project, cx);
                    let throttle = *select_throttle(project_state, request_trigger);

                    throttle.and_then(|(last_entity, last_timestamp)| {
                        if throttle_entity != last_entity {
                            return None;
                        }
                        (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
                    })
                })
                .ok()
                .flatten();

            if let Some(timeout) = throttle_wait {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request. Also skip if another task already
            // proceeded since we were enqueued (duplicate).
            // `is_cancelled` stays `true` if the store entity can no longer be
            // updated, so the request is skipped in that case as well.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                let was_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);
                if was_cancelled {
                    return;
                }

                // Another request has been already sent since this was enqueued
                if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
                    return;
                }

                // Claim the throttle slot for this request.
                let new_refresh = (throttle_entity, Instant::now());
                *select_throttle(project_state, request_trigger) = Some(new_refresh);
                is_cancelled = false;
            })
            .ok();
            if is_cancelled {
                return None;
            }

            // Errors from the refresh are logged and treated as "no result".
            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                // The task may have been cancelled while the request was in
                // flight; in that case its result is discarded.
                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                                shown_with: None,
                            };

                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                // Either the new prediction replaces the current
                                // one (which is then rejected as `Replaced`), or
                                // the new one is rejected as `CurrentPreferred`.
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                        cx,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                        new_prediction.prediction.model_version,
                                        cx,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            // The request itself produced a rejection; report it
                            // as never shown and without a model version.
                            this.reject_prediction(
                                prediction_result.id,
                                reject_reason,
                                false,
                                None,
                                cx,
                            );
                            None
                        }
                    }
                } else {
                    None
                };

                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Take the pending list out of the project state so entries can
                // be removed while `project_state` is used to cancel them; the
                // list is written back afterwards.
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        // Remove this task's own entry, then cancel everything
                        // that was queued ahead of it.
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Below the provider's limit the task is simply appended. At the limit,
        // the most recently queued pending prediction is popped and cancelled,
        // and the new task takes its place.
        if project_state.pending_predictions.len() < max_pending_predictions {
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
        } else {
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
2165
2166    pub fn request_prediction(
2167        &mut self,
2168        project: &Entity<Project>,
2169        active_buffer: &Entity<Buffer>,
2170        position: language::Anchor,
2171        trigger: PredictEditsRequestTrigger,
2172        cx: &mut Context<Self>,
2173    ) -> Task<Result<Option<EditPredictionResult>>> {
2174        self.request_prediction_internal(
2175            project.clone(),
2176            active_buffer.clone(),
2177            position,
2178            trigger,
2179            cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2180            cx,
2181        )
2182    }
2183
2184    fn request_prediction_internal(
2185        &mut self,
2186        project: Entity<Project>,
2187        active_buffer: Entity<Buffer>,
2188        position: language::Anchor,
2189        trigger: PredictEditsRequestTrigger,
2190        allow_jump: bool,
2191        cx: &mut Context<Self>,
2192    ) -> Task<Result<Option<EditPredictionResult>>> {
2193        self.get_or_init_project(&project, cx);
2194        let project_state = self.projects.get(&project.entity_id()).unwrap();
2195        let stored_events = project_state.events(cx);
2196        let has_events = !stored_events.is_empty();
2197        let events: Vec<Arc<zeta_prompt::Event>> =
2198            stored_events.iter().map(|e| e.event.clone()).collect();
2199        let debug_tx = project_state.debug_tx.clone();
2200
2201        let snapshot = active_buffer.read(cx).snapshot();
2202        let cursor_point = position.to_point(&snapshot);
2203        let current_offset = position.to_offset(&snapshot);
2204
2205        let mut user_actions: Vec<UserActionRecord> =
2206            project_state.user_actions.iter().cloned().collect();
2207
2208        if let Some(last_action) = user_actions.last() {
2209            if last_action.buffer_id == active_buffer.entity_id()
2210                && current_offset != last_action.offset
2211            {
2212                let timestamp_epoch_ms = SystemTime::now()
2213                    .duration_since(UNIX_EPOCH)
2214                    .map(|d| d.as_millis() as u64)
2215                    .unwrap_or(0);
2216                user_actions.push(UserActionRecord {
2217                    action_type: UserActionType::CursorMovement,
2218                    buffer_id: active_buffer.entity_id(),
2219                    line_number: cursor_point.row,
2220                    offset: current_offset,
2221                    timestamp_epoch_ms,
2222                });
2223            }
2224        }
2225        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2226        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2227        let diagnostic_search_range =
2228            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2229
2230        let related_files = self.context_for_project(&project, cx);
2231
2232        let is_open_source = snapshot
2233            .file()
2234            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2235            && events.iter().all(|event| event.in_open_source_repo())
2236            && related_files.iter().all(|file| file.in_open_source_repo);
2237
2238        let can_collect_data = !cfg!(test)
2239            && is_open_source
2240            && self.is_data_collection_enabled(cx)
2241            && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2242
2243        let recent_paths = project_state.recent_paths.clone();
2244
2245        let inputs = EditPredictionModelInput {
2246            project: project.clone(),
2247            buffer: active_buffer,
2248            snapshot,
2249            position,
2250            events,
2251            related_files,
2252            recent_paths,
2253            trigger,
2254            diagnostic_search_range: diagnostic_search_range,
2255            debug_tx,
2256            user_actions,
2257            can_collect_data,
2258            is_open_source,
2259        };
2260
2261        let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2262
2263        let task = match self.edit_prediction_model {
2264            EditPredictionModel::Zeta => {
2265                zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2266            }
2267            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2268            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2269            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2270        };
2271
2272        cx.spawn(async move |this, cx| {
2273            let prediction = task.await?;
2274
2275            // Only fall back to diagnostics-based prediction if we got a
2276            // the model had nothing to suggest for the buffer
2277            if prediction.is_none()
2278                && allow_jump
2279                && has_events
2280                && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2281            {
2282                this.update(cx, |this, cx| {
2283                    this.refresh_prediction_from_diagnostics(
2284                        project,
2285                        DiagnosticSearchScope::Local,
2286                        cx,
2287                    );
2288                })?;
2289                return anyhow::Ok(None);
2290            }
2291
2292            Ok(prediction)
2293        })
2294    }
2295
    /// Picks the diagnostic location that should be jumped to next, if any.
    ///
    /// The active buffer is searched first. When it yields nothing, other project paths
    /// that have diagnostic summaries are opened one at a time, ordered by how many
    /// leading path components they share with the active buffer's path.
    ///
    /// Returns the buffer containing the chosen diagnostic together with an anchor for it,
    /// or `Ok(None)` when no suitable diagnostic is found.
    ///
    /// # Errors
    ///
    /// Propagates the error from opening a candidate buffer.
    pub(crate) async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Rows of every selection head reported for the whole buffer.
        // NOTE(review): the `false` argument presumably excludes the local user's own
        // selections, leaving only collaborators — confirm against `selections_in_range`.
        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
            .flat_map(|(_, _, _, selections)| {
                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
            })
            .collect();

        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                // Only the primary entry of each diagnostic group is considered.
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Skip diagnostics that overlap the range passed in by the caller.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    return None;
                }
                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
                });
                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
                    <= DIAGNOSTIC_LINES_RANGE;
                // Skip diagnostics that are close to a collaborator's cursor unless they
                // are also close to the local cursor.
                if near_collaborator && !near_local {
                    return None;
                }
                Some(range.start)
            })
            // Of the remaining diagnostics, prefer the one whose row is closest to the cursor.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Nothing usable in the active buffer: look at other paths with diagnostics.
        if jump_location.is_none() {
            // `None` when the active buffer has no associated file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // Every other path with a diagnostic summary, paired with the number of
            // leading path components it shares with the active buffer's path (zero when
            // the active buffer has no path).
            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
                project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .map(|(path, _, _)| {
                        let shared_prefix = path
                            .path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count();
                        (path, shared_prefix)
                    })
                    .collect()
            });

            // Longest shared prefix first; the sort is stable, so ties keep their
            // original relative order.
            candidates.sort_by(|a, b| b.1.cmp(&a.1));

            for (path, _) in candidates {
                let candidate_buffer = project
                    .update(cx, |project, cx| project.open_buffer(path, cx))
                    .await?;

                let (has_collaborators, diagnostic_position) =
                    candidate_buffer.read_with(cx, |buffer, _cx| {
                        let snapshot = buffer.snapshot();
                        // True when `selections_in_range` reports at least one entry for
                        // this buffer.
                        let has_collaborators = snapshot
                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
                            .next()
                            .is_some();
                        // Start of the diagnostic with the smallest `severity` value.
                        // NOTE(review): presumably the most severe one, assuming LSP-style
                        // ordering where errors sort lowest — confirm.
                        let position = buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start);
                        (has_collaborators, position)
                    });

                // Buffers that report any such selection are skipped entirely.
                if has_collaborators {
                    continue;
                }

                // The first candidate that still has a diagnostic wins.
                if let Some(position) = diagnostic_position {
                    jump_location = Some((candidate_buffer, position));
                    break;
                }
            }
        }

        anyhow::Ok(jump_location)
    }
2405
2406    async fn send_raw_llm_request(
2407        request: RawCompletionRequest,
2408        client: Arc<Client>,
2409        custom_url: Option<Arc<Url>>,
2410        llm_token: LlmApiToken,
2411        organization_id: Option<OrganizationId>,
2412        app_version: Version,
2413    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2414        let url = if let Some(custom_url) = custom_url {
2415            custom_url.as_ref().clone()
2416        } else {
2417            client
2418                .http_client()
2419                .build_zed_llm_url("/predict_edits/raw", &[])?
2420        };
2421
2422        Self::send_api_request(
2423            |builder| {
2424                let req = builder
2425                    .uri(url.as_ref())
2426                    .body(serde_json::to_string(&request)?.into());
2427                Ok(req?)
2428            },
2429            client,
2430            llm_token,
2431            organization_id,
2432            app_version,
2433            true,
2434        )
2435        .await
2436    }
2437
2438    pub(crate) async fn send_v3_request(
2439        input: ZetaPromptInput,
2440        client: Arc<Client>,
2441        llm_token: LlmApiToken,
2442        organization_id: Option<OrganizationId>,
2443        app_version: Version,
2444        trigger: PredictEditsRequestTrigger,
2445    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2446        let url = client
2447            .http_client()
2448            .build_zed_llm_url("/predict_edits/v3", &[])?;
2449
2450        let request = PredictEditsV3Request { input, trigger };
2451
2452        let json_bytes = serde_json::to_vec(&request)?;
2453        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2454
2455        Self::send_api_request(
2456            |builder| {
2457                let req = builder
2458                    .uri(url.as_ref())
2459                    .header("Content-Encoding", "zstd")
2460                    .body(compressed.clone().into());
2461                Ok(req?)
2462            },
2463            client,
2464            llm_token,
2465            organization_id,
2466            app_version,
2467            true,
2468        )
2469        .await
2470    }
2471
2472    async fn send_api_request<Res>(
2473        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2474        client: Arc<Client>,
2475        llm_token: LlmApiToken,
2476        organization_id: Option<OrganizationId>,
2477        app_version: Version,
2478        require_auth: bool,
2479    ) -> Result<(Res, Option<EditPredictionUsage>)>
2480    where
2481        Res: DeserializeOwned,
2482    {
2483        let http_client = client.http_client();
2484
2485        let mut token = if require_auth {
2486            Some(llm_token.acquire(&client, organization_id.clone()).await?)
2487        } else {
2488            llm_token
2489                .acquire(&client, organization_id.clone())
2490                .await
2491                .ok()
2492        };
2493        let mut did_retry = false;
2494
2495        loop {
2496            let request_builder = http_client::Request::builder().method(Method::POST);
2497
2498            let mut request_builder = request_builder
2499                .header("Content-Type", "application/json")
2500                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2501
2502            // Only add Authorization header if we have a token
2503            if let Some(ref token_value) = token {
2504                request_builder =
2505                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2506            }
2507
2508            let request = build(request_builder)?;
2509
2510            let mut response = http_client.send(request).await?;
2511
2512            if let Some(minimum_required_version) = response
2513                .headers()
2514                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2515                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2516            {
2517                anyhow::ensure!(
2518                    app_version >= minimum_required_version,
2519                    ZedUpdateRequiredError {
2520                        minimum_version: minimum_required_version
2521                    }
2522                );
2523            }
2524
2525            if response.status().is_success() {
2526                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2527
2528                let mut body = Vec::new();
2529                response.body_mut().read_to_end(&mut body).await?;
2530                return Ok((serde_json::from_slice(&body)?, usage));
2531            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2532                did_retry = true;
2533                token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
2534            } else {
2535                let mut body = String::new();
2536                response.body_mut().read_to_string(&mut body).await?;
2537                anyhow::bail!(
2538                    "Request failed with status: {:?}\nBody: {}",
2539                    response.status(),
2540                    body
2541                );
2542            }
2543        }
2544    }
2545
2546    pub fn refresh_context(
2547        &mut self,
2548        project: &Entity<Project>,
2549        buffer: &Entity<language::Buffer>,
2550        cursor_position: language::Anchor,
2551        cx: &mut Context<Self>,
2552    ) {
2553        self.get_or_init_project(project, cx)
2554            .context
2555            .update(cx, |store, cx| {
2556                store.refresh(buffer.clone(), cursor_position, cx);
2557            });
2558    }
2559
2560    #[cfg(feature = "cli-support")]
2561    pub fn set_context_for_buffer(
2562        &mut self,
2563        project: &Entity<Project>,
2564        related_files: Vec<RelatedFile>,
2565        cx: &mut Context<Self>,
2566    ) {
2567        self.get_or_init_project(project, cx)
2568            .context
2569            .update(cx, |store, cx| {
2570                store.set_related_files(related_files, cx);
2571            });
2572    }
2573
2574    #[cfg(feature = "cli-support")]
2575    pub fn set_recent_paths_for_project(
2576        &mut self,
2577        project: &Entity<Project>,
2578        paths: impl IntoIterator<Item = project::ProjectPath>,
2579        cx: &mut Context<Self>,
2580    ) {
2581        let project_state = self.get_or_init_project(project, cx);
2582        project_state.recent_paths = paths.into_iter().collect();
2583    }
2584
2585    fn is_file_open_source(
2586        &self,
2587        project: &Entity<Project>,
2588        file: &Arc<dyn File>,
2589        cx: &App,
2590    ) -> bool {
2591        if !file.is_local() || file.is_private() {
2592            return false;
2593        }
2594        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2595            return false;
2596        };
2597        project_state
2598            .license_detection_watchers
2599            .get(&file.worktree_id(cx))
2600            .as_ref()
2601            .is_some_and(|watcher| watcher.is_project_open_source())
2602    }
2603
2604    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2605        self.data_collection_choice.is_enabled(cx)
2606    }
2607
2608    fn load_data_collection_choice() -> DataCollectionChoice {
2609        let choice = KEY_VALUE_STORE
2610            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2611            .log_err()
2612            .flatten();
2613
2614        match choice.as_deref() {
2615            Some("true") => DataCollectionChoice::Enabled,
2616            Some("false") => DataCollectionChoice::Disabled,
2617            Some(_) => {
2618                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2619                DataCollectionChoice::NotAnswered
2620            }
2621            None => DataCollectionChoice::NotAnswered,
2622        }
2623    }
2624
2625    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2626        self.data_collection_choice = self.data_collection_choice.toggle();
2627        let new_choice = self.data_collection_choice;
2628        let is_enabled = new_choice.is_enabled(cx);
2629        db::write_and_log(cx, move || {
2630            KEY_VALUE_STORE.write_kvp(
2631                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2632                is_enabled.to_string(),
2633            )
2634        });
2635    }
2636
2637    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2638        self.shown_predictions.iter()
2639    }
2640
2641    pub fn shown_completions_len(&self) -> usize {
2642        self.shown_predictions.len()
2643    }
2644
2645    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2646        self.rated_predictions.contains(id)
2647    }
2648
2649    pub fn rate_prediction(
2650        &mut self,
2651        prediction: &EditPrediction,
2652        rating: EditPredictionRating,
2653        feedback: String,
2654        cx: &mut Context<Self>,
2655    ) {
2656        let organization = self.user_store.read(cx).current_organization();
2657
2658        self.rated_predictions.insert(prediction.id.clone());
2659
2660        cx.background_spawn({
2661            let client = self.client.clone();
2662            let prediction_id = prediction.id.to_string();
2663            let inputs = serde_json::to_value(&prediction.inputs);
2664            let output = prediction
2665                .edit_preview
2666                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2667            async move {
2668                client
2669                    .cloud_client()
2670                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2671                        organization_id: organization.map(|organization| organization.id.clone()),
2672                        request_id: prediction_id,
2673                        rating: match rating {
2674                            EditPredictionRating::Positive => "positive".to_string(),
2675                            EditPredictionRating::Negative => "negative".to_string(),
2676                        },
2677                        inputs: inputs?,
2678                        output,
2679                        feedback,
2680                    })
2681                    .await?;
2682
2683                anyhow::Ok(())
2684            }
2685        })
2686        .detach_and_log_err(cx);
2687
2688        cx.notify();
2689    }
2690}
2691
2692fn merge_trailing_events_if_needed(
2693    events: &mut VecDeque<StoredEvent>,
2694    end_snapshot: &TextBufferSnapshot,
2695    latest_snapshot: &TextBufferSnapshot,
2696    latest_edit_range: &Range<Anchor>,
2697) {
2698    if let Some(last_event) = events.back() {
2699        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
2700            return;
2701        }
2702    }
2703
2704    let mut next_old_event = None;
2705    let mut mergeable_count = 0;
2706    for old_event in events.iter().rev() {
2707        if let Some(next_old_event) = &next_old_event
2708            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
2709        {
2710            break;
2711        }
2712        mergeable_count += 1;
2713        next_old_event = Some(old_event);
2714    }
2715
2716    if mergeable_count <= 1 {
2717        return;
2718    }
2719
2720    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
2721    let oldest_event = events_to_merge.peek().unwrap();
2722    let oldest_snapshot = oldest_event.old_snapshot.clone();
2723
2724    if let Some((diff, edited_range)) =
2725        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
2726    {
2727        let merged_event = match oldest_event.event.as_ref() {
2728            zeta_prompt::Event::BufferChange {
2729                old_path,
2730                path,
2731                in_open_source_repo,
2732                ..
2733            } => StoredEvent {
2734                event: Arc::new(zeta_prompt::Event::BufferChange {
2735                    old_path: old_path.clone(),
2736                    path: path.clone(),
2737                    diff,
2738                    in_open_source_repo: *in_open_source_repo,
2739                    predicted: events_to_merge.all(|e| {
2740                        matches!(
2741                            e.event.as_ref(),
2742                            zeta_prompt::Event::BufferChange {
2743                                predicted: true,
2744                                ..
2745                            }
2746                        )
2747                    }),
2748                }),
2749                old_snapshot: oldest_snapshot.clone(),
2750                edit_range: end_snapshot.anchor_before(edited_range.start)
2751                    ..end_snapshot.anchor_before(edited_range.end),
2752            },
2753        };
2754        events.truncate(events.len() - mergeable_count);
2755        events.push_back(merged_event);
2756    }
2757}
2758
/// Error raised when a response names a minimum required Zed version that is newer than
/// the running app version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum version taken from the response header; shown in the error message.
    minimum_version: Version,
}
2766
/// The stored answer to whether edit prediction data collection is turned on.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No recognized answer is stored.
    NotAnswered,
    /// Data collection has been turned on.
    Enabled,
    /// Data collection has been turned off.
    Disabled,
}
2773
2774impl DataCollectionChoice {
2775    pub fn is_enabled(self, cx: &App) -> bool {
2776        if cx.is_staff() {
2777            return true;
2778        }
2779        match self {
2780            Self::Enabled => true,
2781            Self::NotAnswered | Self::Disabled => false,
2782        }
2783    }
2784
2785    #[must_use]
2786    pub fn toggle(&self) -> DataCollectionChoice {
2787        match self {
2788            Self::Enabled => Self::Disabled,
2789            Self::Disabled => Self::Enabled,
2790            Self::NotAnswered => Self::Enabled,
2791        }
2792    }
2793}
2794
2795impl From<bool> for DataCollectionChoice {
2796    fn from(value: bool) -> Self {
2797        match value {
2798            true => DataCollectionChoice::Enabled,
2799            false => DataCollectionChoice::Disabled,
2800        }
2801    }
2802}
2803
2804struct ZedPredictUpsell;
2805
2806impl Dismissable for ZedPredictUpsell {
2807    const KEY: &'static str = "dismissed-edit-predict-upsell";
2808
2809    fn dismissed() -> bool {
2810        // To make this backwards compatible with older versions of Zed, we
2811        // check if the user has seen the previous Edit Prediction Onboarding
2812        // before, by checking the data collection choice which was written to
2813        // the database once the user clicked on "Accept and Enable"
2814        if KEY_VALUE_STORE
2815            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2816            .log_err()
2817            .is_some_and(|s| s.is_some())
2818        {
2819            return true;
2820        }
2821
2822        KEY_VALUE_STORE
2823            .read_kvp(Self::KEY)
2824            .log_err()
2825            .is_some_and(|s| s.is_some())
2826    }
2827}
2828
2829pub fn should_show_upsell_modal() -> bool {
2830    !ZedPredictUpsell::dismissed()
2831}
2832
2833pub fn init(cx: &mut App) {
2834    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2835        workspace.register_action(
2836            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2837                ZedPredictModal::toggle(
2838                    workspace,
2839                    workspace.user_store().clone(),
2840                    workspace.client().clone(),
2841                    window,
2842                    cx,
2843                )
2844            },
2845        );
2846
2847        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2848            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2849                settings
2850                    .project
2851                    .all_languages
2852                    .edit_predictions
2853                    .get_or_insert_default()
2854                    .provider = Some(EditPredictionProvider::None)
2855            });
2856        });
2857        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2858            EditPredictionStore::try_global(cx).and_then(|store| {
2859                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2860            })
2861        }
2862
2863        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2864            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2865                copilot_ui::initiate_sign_in(copilot, window, cx);
2866            }
2867        });
2868        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2869            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2870                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2871            }
2872        });
2873        workspace.register_action(|workspace, _: &SignOut, window, cx| {
2874            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2875                copilot_ui::initiate_sign_out(copilot, window, cx);
2876            }
2877        });
2878    })
2879    .detach();
2880}