edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_api_types::SubmitEditPredictionFeedbackBody;
   5use cloud_llm_client::predict_edits_v3::{
   6    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   7};
   8use cloud_llm_client::{
   9    EditPredictionRejectReason, EditPredictionRejection,
  10    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  11    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  12};
  13use collections::{HashMap, HashSet};
  14use copilot::{Copilot, Reinstall, SignIn, SignOut};
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::{
  19    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  20    channel::mpsc::{self, UnboundedReceiver},
  21    select_biased,
  22};
  23use gpui::BackgroundExecutor;
  24use gpui::http_client::Url;
  25use gpui::{
  26    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  27    http_client::{self, AsyncBody, Method},
  28    prelude::*,
  29};
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{
  39    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  40};
  41use std::collections::{VecDeque, hash_map};
  42use std::env;
  43use text::{AnchorRangeExt, Edit};
  44use workspace::Workspace;
  45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  46
  47use std::mem;
  48use std::ops::Range;
  49use std::path::Path;
  50use std::rc::Rc;
  51use std::str::FromStr as _;
  52use std::sync::Arc;
  53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  54use thiserror::Error;
  55use util::{RangeExt as _, ResultExt as _};
  56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  57
  58pub mod cursor_excerpt;
  59pub mod example_spec;
  60pub mod fim;
  61mod license_detection;
  62pub mod mercury;
  63pub mod ollama;
  64mod onboarding_modal;
  65pub mod open_ai_response;
  66mod prediction;
  67pub mod sweep_ai;
  68
  69pub mod udiff;
  70
  71mod capture_example;
  72mod zed_edit_prediction_delegate;
  73pub mod zeta;
  74
  75#[cfg(test)]
  76mod edit_prediction_tests;
  77
  78use crate::license_detection::LicenseDetectionWatcher;
  79use crate::mercury::Mercury;
  80use crate::onboarding_modal::ZedPredictModal;
  81pub use crate::prediction::EditPrediction;
  82pub use crate::prediction::EditPredictionId;
  83use crate::prediction::EditPredictionResult;
  84pub use crate::sweep_ai::SweepAi;
  85pub use capture_example::capture_example;
  86pub use language_model::ApiKeyState;
  87pub use telemetry_events::EditPredictionRating;
  88pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  89
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  99
 100/// Maximum number of events to track.
 101const EVENT_COUNT_MAX: usize = 6;
 102const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 103const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
 104const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 105const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 106const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
 107const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
 108const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
 109
 110pub struct Zeta2FeatureFlag;
 111pub struct EditPredictionJumpsFeatureFlag;
 112
 113impl FeatureFlag for Zeta2FeatureFlag {
 114    const NAME: &'static str = "zeta2";
 115}
 116
 117impl FeatureFlag for EditPredictionJumpsFeatureFlag {
 118    const NAME: &'static str = "edit_prediction_jumps";
 119}
 120
/// App-global handle to the shared `EditPredictionStore`. Installed by
/// `EditPredictionStore::global` and read back by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 125
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
// NOTE(review): this struct has no `version` field; the sentence above
// presumably refers to `format` — confirm and reword.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier; `zeta2_raw_config_from_env` reads it from
    /// `ZED_ZETA_MODEL`.
    pub model_id: Option<String>,
    /// Prompt format; `zeta2_raw_config_from_env` parses it from `ZED_ZETA_FORMAT`.
    pub format: ZetaFormat,
}
 134
/// App-wide edit prediction state, shared through `EditPredictionStoreGlobal`.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token for authenticating against the Zed LLM service; refreshed every
    /// time the global `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Waits for a signed-in user, then calls `refresh_available_experiments`.
    _fetch_experiments_task: Task<()>,
    /// Per-project state keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    // NOTE(review): only initialized to `false` in the visible code; its
    // meaning is defined elsewhere in the file.
    update_required: bool,
    /// Backend used to produce predictions.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration; initialized from the environment.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    preferred_experiment: Option<String>,
    /// Experiment names fetched from `/edit_prediction_experiments`.
    available_experiments: Vec<String>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Feeds the background `handle_rejected_predictions` worker.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    /// Feeds the `run_settled_predictions_worker` task.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
 157
/// The backend used to produce edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    /// Zeta (uses the `ZedPredict` icon set).
    Zeta,
    /// Prompt-based completion with the given prompt format; the icon depends
    /// on the configured provider (Ollama or OpenAI-compatible).
    Fim { format: EditPredictionPromptFormat },
    /// Sweep AI.
    Sweep,
    /// Mercury (uses the `Inception` icon).
    Mercury,
}
 165
/// Inputs gathered for a single edit prediction request (presumably passed to
/// the active backend — confirm at the call sites).
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` taken for this request.
    snapshot: BufferSnapshot,
    /// Position in `buffer` the prediction is requested for.
    position: Anchor,
    /// Edit history events for the project.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    // NOTE(review): presumably the row range searched for diagnostics.
    diagnostic_search_range: Range<Point>,
    /// Optional sink for `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    pub user_actions: Vec<UserActionRecord>,
}
 182
/// Diagnostic events that can be sent on a project's optional debug channel
/// (`debug_tx`).
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
 190
/// Payload for `DebugEvent::ContextRetrievalStarted`.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Payload for `DebugEvent::ContextRetrievalFinished`.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Labelled values describing the retrieval.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Payload for `DebugEvent::EditPredictionStarted`.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// The prompt for the request, if available.
    pub prompt: Option<String>,
}

/// Payload for `DebugEvent::EditPredictionFinished`.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// The raw model output, if available.
    pub model_output: Option<String>,
}
 218
/// Maximum number of `UserActionRecord`s kept per project; at this size
/// `ProjectState::record_user_action` evicts the oldest record.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// A single user action recorded in a project's action history.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    // NOTE(review): presumably the action's location in that buffer as a row
    // and an offset — confirm at the recording site.
    pub line_number: u32,
    pub offset: usize,
    /// Timestamp in milliseconds since the Unix epoch (per the field name).
    pub timestamp_epoch_ms: u64,
}

/// Kind of user action stored in a `UserActionRecord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 238
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The recorded change (built by `LastEvent::finalize` as a `BufferChange`).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Anchors spanning the changed region in the post-change snapshot.
    pub edit_range: Range<Anchor>,
}
 246
 247impl StoredEvent {
 248    fn can_merge(
 249        &self,
 250        next_old_event: &&&StoredEvent,
 251        new_snapshot: &TextBufferSnapshot,
 252        last_edit_range: &Range<Anchor>,
 253    ) -> bool {
 254        // Events must be for the same buffer
 255        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 256            return false;
 257        }
 258        if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
 259            return false;
 260        }
 261
 262        let a_is_predicted = matches!(
 263            self.event.as_ref(),
 264            zeta_prompt::Event::BufferChange {
 265                predicted: true,
 266                ..
 267            }
 268        );
 269        let b_is_predicted = matches!(
 270            next_old_event.event.as_ref(),
 271            zeta_prompt::Event::BufferChange {
 272                predicted: true,
 273                ..
 274            }
 275        );
 276
 277        // If events come from the same source (both predicted or both manual) then
 278        // we would have coalesced them already.
 279        if a_is_predicted == b_is_predicted {
 280            return false;
 281        }
 282
 283        let left_range = self.edit_range.to_point(new_snapshot);
 284        let right_range = next_old_event.edit_range.to_point(new_snapshot);
 285        let latest_range = last_edit_range.to_point(&new_snapshot);
 286
 287        // Events near to the latest edit are not merged if their sources differ.
 288        if lines_between_ranges(&left_range, &latest_range)
 289            .min(lines_between_ranges(&right_range, &latest_range))
 290            <= CHANGE_GROUPING_LINE_SPAN
 291        {
 292            return false;
 293        }
 294
 295        // Events that are distant from each other are not merged.
 296        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 297            return false;
 298        }
 299
 300        true
 301    }
 302}
 303
 304fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 305    if left.start > right.end {
 306        return left.start.row - right.end.row;
 307    }
 308    if right.start > left.end {
 309        return right.start.row - left.end.row;
 310    }
 311    0
 312}
 313
/// Edit prediction state tracked for one project.
struct ProjectState {
    /// Finalized edit events; `events()` returns these ahead of `last_event`.
    events: VecDeque<StoredEvent>,
    /// The event still being accumulated; finalized on demand by `events()`.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Tracked buffers keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (capacity two).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Optional sink for `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that have been cancelled.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree; consulted when flagging events as coming
    /// from an open-source repository.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history of user actions (see `USER_ACTION_HISTORY_SIZE`).
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
 332
 333impl ProjectState {
 334    fn record_user_action(&mut self, action: UserActionRecord) {
 335        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 336            self.user_actions.pop_front();
 337        }
 338        self.user_actions.push_back(action);
 339    }
 340
 341    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 342        self.events
 343            .iter()
 344            .cloned()
 345            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 346                let (one, two) = event.split_by_pause();
 347                let one = one.finalize(&self.license_detection_watchers, cx);
 348                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 349                one.into_iter().chain(two)
 350            }))
 351            .collect()
 352    }
 353
 354    fn cancel_pending_prediction(
 355        &mut self,
 356        pending_prediction: PendingPrediction,
 357        cx: &mut Context<EditPredictionStore>,
 358    ) {
 359        self.cancelled_predictions.insert(pending_prediction.id);
 360
 361        if pending_prediction.drop_on_cancel {
 362            drop(pending_prediction.task);
 363        } else {
 364            cx.spawn(async move |this, cx| {
 365                let Some(prediction_id) = pending_prediction.task.await else {
 366                    return;
 367                };
 368
 369                this.update(cx, |this, cx| {
 370                    this.reject_prediction(
 371                        prediction_id,
 372                        EditPredictionRejectReason::Canceled,
 373                        false,
 374                        None,
 375                        cx,
 376                    );
 377                })
 378                .ok();
 379            })
 380            .detach()
 381        }
 382    }
 383
 384    fn active_buffer(
 385        &self,
 386        project: &Entity<Project>,
 387        cx: &App,
 388    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 389        let project = project.read(cx);
 390        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 391        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 392        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 393        Some((active_buffer, registered_buffer.last_position))
 394    }
 395}
 396
/// The prediction currently offered for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // NOTE(review): presumably whether and how the prediction was displayed to
    // the user — confirm where these are set.
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
 404
 405impl CurrentEditPrediction {
 406    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 407        let Some(new_edits) = self
 408            .prediction
 409            .interpolate(&self.prediction.buffer.read(cx))
 410        else {
 411            return false;
 412        };
 413
 414        if self.prediction.buffer != old_prediction.prediction.buffer {
 415            return true;
 416        }
 417
 418        let Some(old_edits) = old_prediction
 419            .prediction
 420            .interpolate(&old_prediction.prediction.buffer.read(cx))
 421        else {
 422            return true;
 423        };
 424
 425        let requested_by_buffer_id = self.requested_by.buffer_id();
 426
 427        // This reduces the occurrence of UI thrash from replacing edits
 428        //
 429        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 430        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 431            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 432            && old_edits.len() == 1
 433            && new_edits.len() == 1
 434        {
 435            let (old_range, old_text) = &old_edits[0];
 436            let (new_range, new_text) = &new_edits[0];
 437            new_range == old_range && new_text.starts_with(old_text.as_ref())
 438        } else {
 439            true
 440        }
 441    }
 442}
 443
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// A diagnostics update; not tied to a specific buffer.
    DiagnosticsUpdate,
    /// Activity in the buffer with this entity id.
    Buffer(EntityId),
}
 449
 450impl PredictionRequestedBy {
 451    pub fn buffer_id(&self) -> Option<EntityId> {
 452        match self {
 453            PredictionRequestedBy::DiagnosticsUpdate => None,
 454            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 455        }
 456    }
 457}
 458
// NOTE(review): not referenced in this part of the file; presumably the number
// of lines around a position searched for diagnostics — confirm at use site.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostic search (presumably near the cursor vs. project-wide —
/// confirm at the use site).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
 466
/// An in-flight prediction request. Its task resolves to the id of the
/// prediction it produced, if any.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 475
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    // NOTE(review): presumably `Local` means the prediction edits this buffer
    // and `Jump` means it points elsewhere — confirm where these are built.
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
 482
 483#[cfg(test)]
 484impl std::ops::Deref for BufferEditPrediction<'_> {
 485    type Target = EditPrediction;
 486
 487    fn deref(&self) -> &Self::Target {
 488        match self {
 489            BufferEditPrediction::Local { prediction } => prediction,
 490            BufferEditPrediction::Jump { prediction } => prediction,
 491        }
 492    }
 493}
 494
/// A prediction waiting to be reported as settled (presumably once edits in
/// its editable range stop — confirm in the settled-prediction worker).
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    editable_anchor_range: Range<Anchor>,
    enqueued_at: Instant,
    last_edit_at: Instant,
}

/// Per-buffer tracking state held in `ProjectState::registered_buffers`.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Returned by `ProjectState::active_buffer` (presumably the last cursor
    /// position seen in this buffer).
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
 510
/// An edit event still being accumulated; turned into a `StoredEvent` by
/// `finalize`.
#[derive(Clone)]
struct LastEvent {
    /// Buffer state when the event began.
    old_snapshot: TextBufferSnapshot,
    /// Most recent buffer state covered by the event.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    // NOTE(review): not read in this part of the file; `split_by_pause`
    // resets it to `None`.
    edit_range: Option<Range<Anchor>>,
    /// Whether the edits came from an accepted prediction.
    predicted: bool,
    /// Snapshot at the last editing pause; `split_by_pause` splits here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 522
 523impl LastEvent {
 524    pub fn finalize(
 525        &self,
 526        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 527        cx: &App,
 528    ) -> Option<StoredEvent> {
 529        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 530        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 531
 532        let in_open_source_repo =
 533            [self.new_file.as_ref(), self.old_file.as_ref()]
 534                .iter()
 535                .all(|file| {
 536                    file.is_some_and(|file| {
 537                        license_detection_watchers
 538                            .get(&file.worktree_id(cx))
 539                            .is_some_and(|watcher| watcher.is_project_open_source())
 540                    })
 541                });
 542
 543        let (diff, edit_range) =
 544            compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 545
 546        if path == old_path && diff.is_empty() {
 547            None
 548        } else {
 549            Some(StoredEvent {
 550                event: Arc::new(zeta_prompt::Event::BufferChange {
 551                    old_path,
 552                    path,
 553                    diff,
 554                    in_open_source_repo,
 555                    predicted: self.predicted,
 556                }),
 557                edit_range: self.new_snapshot.anchor_before(edit_range.start)
 558                    ..self.new_snapshot.anchor_before(edit_range.end),
 559                old_snapshot: self.old_snapshot.clone(),
 560            })
 561        }
 562    }
 563
 564    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 565        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 566            return (self.clone(), None);
 567        };
 568
 569        let before = LastEvent {
 570            old_snapshot: self.old_snapshot.clone(),
 571            new_snapshot: boundary_snapshot.clone(),
 572            old_file: self.old_file.clone(),
 573            new_file: self.new_file.clone(),
 574            edit_range: None,
 575            predicted: self.predicted,
 576            snapshot_after_last_editing_pause: None,
 577            last_edit_time: self.last_edit_time,
 578        };
 579
 580        let after = LastEvent {
 581            old_snapshot: boundary_snapshot.clone(),
 582            new_snapshot: self.new_snapshot.clone(),
 583            old_file: self.old_file.clone(),
 584            new_file: self.new_file.clone(),
 585            edit_range: None,
 586            predicted: self.predicted,
 587            snapshot_after_last_editing_pause: None,
 588            last_edit_time: self.last_edit_time,
 589        };
 590
 591        (before, Some(after))
 592    }
 593}
 594
 595pub(crate) fn compute_diff_between_snapshots(
 596    old_snapshot: &TextBufferSnapshot,
 597    new_snapshot: &TextBufferSnapshot,
 598) -> Option<(String, Range<Point>)> {
 599    let edits: Vec<Edit<usize>> = new_snapshot
 600        .edits_since::<usize>(&old_snapshot.version)
 601        .collect();
 602
 603    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 604
 605    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 606    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 607    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 608    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 609
 610    const CONTEXT_LINES: u32 = 3;
 611
 612    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 613    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 614    let old_context_end_row =
 615        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 616    let new_context_end_row =
 617        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 618
 619    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 620    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 621    let old_end_line_offset = old_snapshot
 622        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 623    let new_end_line_offset = new_snapshot
 624        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 625    let old_edit_range = old_start_line_offset..old_end_line_offset;
 626    let new_edit_range = new_start_line_offset..new_end_line_offset;
 627
 628    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 629    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 630
 631    let diff = language::unified_diff_with_offsets(
 632        &old_region_text,
 633        &new_region_text,
 634        old_context_start_row,
 635        new_context_start_row,
 636    );
 637
 638    Some((diff, new_start_point..new_end_point))
 639}
 640
 641fn buffer_path_with_id_fallback(
 642    file: Option<&Arc<dyn File>>,
 643    snapshot: &TextBufferSnapshot,
 644    cx: &App,
 645) -> Arc<Path> {
 646    if let Some(file) = file {
 647        file.full_path(cx).into()
 648    } else {
 649        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 650    }
 651}
 652
 653impl EditPredictionStore {
 654    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 655        cx.try_global::<EditPredictionStoreGlobal>()
 656            .map(|global| global.0.clone())
 657    }
 658
 659    pub fn global(
 660        client: &Arc<Client>,
 661        user_store: &Entity<UserStore>,
 662        cx: &mut App,
 663    ) -> Entity<Self> {
 664        cx.try_global::<EditPredictionStoreGlobal>()
 665            .map(|global| global.0.clone())
 666            .unwrap_or_else(|| {
 667                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 668                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 669                ep_store
 670            })
 671    }
 672
 673    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 674        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 675        let data_collection_choice = Self::load_data_collection_choice();
 676
 677        let llm_token = LlmApiToken::default();
 678
 679        let (reject_tx, reject_rx) = mpsc::unbounded();
 680        cx.background_spawn({
 681            let client = client.clone();
 682            let llm_token = llm_token.clone();
 683            let app_version = AppVersion::global(cx);
 684            let background_executor = cx.background_executor().clone();
 685            async move {
 686                Self::handle_rejected_predictions(
 687                    reject_rx,
 688                    client,
 689                    llm_token,
 690                    app_version,
 691                    background_executor,
 692                )
 693                .await
 694            }
 695        })
 696        .detach();
 697
 698        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
 699        cx.spawn(async move |this, cx| {
 700            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
 701        })
 702        .detach();
 703
 704        let mut current_user = user_store.read(cx).watch_current_user();
 705        let fetch_experiments_task = cx.spawn(async move |this, cx| {
 706            while current_user.borrow().is_none() {
 707                current_user.next().await;
 708            }
 709            this.update(cx, |this, cx| {
 710                this.refresh_available_experiments(cx);
 711            })
 712            .log_err();
 713        });
 714
 715        let this = Self {
 716            projects: HashMap::default(),
 717            client,
 718            user_store,
 719            llm_token,
 720            _fetch_experiments_task: fetch_experiments_task,
 721            _llm_token_subscription: cx.subscribe(
 722                &refresh_llm_token_listener,
 723                |this, _listener, _event, cx| {
 724                    let client = this.client.clone();
 725                    let llm_token = this.llm_token.clone();
 726                    cx.spawn(async move |_this, _cx| {
 727                        llm_token.refresh(&client).await?;
 728                        anyhow::Ok(())
 729                    })
 730                    .detach_and_log_err(cx);
 731                },
 732            ),
 733            update_required: false,
 734            edit_prediction_model: EditPredictionModel::Zeta,
 735            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 736            preferred_experiment: None,
 737            available_experiments: Vec::new(),
 738            sweep_ai: SweepAi::new(cx),
 739            mercury: Mercury::new(cx),
 740
 741            data_collection_choice,
 742            reject_predictions_tx: reject_tx,
 743            settled_predictions_tx,
 744            rated_predictions: Default::default(),
 745            shown_predictions: Default::default(),
 746            #[cfg(test)]
 747            settled_event_callback: None,
 748        };
 749
 750        this
 751    }
 752
 753    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 754        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 755        let format = ZetaFormat::parse(&version_str).ok()?;
 756        let model_id = env::var("ZED_ZETA_MODEL").ok();
 757        Some(Zeta2RawConfig { model_id, format })
 758    }
 759
 760    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 761        self.edit_prediction_model = model;
 762    }
 763
 764    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
 765        self.zeta2_raw_config = Some(config);
 766    }
 767
 768    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
 769        self.zeta2_raw_config.as_ref()
 770    }
 771
 772    pub fn preferred_experiment(&self) -> Option<&str> {
 773        self.preferred_experiment.as_deref()
 774    }
 775
 776    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
 777        self.preferred_experiment = experiment;
 778    }
 779
 780    pub fn available_experiments(&self) -> &[String] {
 781        &self.available_experiments
 782    }
 783
 784    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
 785        let client = self.client.clone();
 786        let llm_token = self.llm_token.clone();
 787        let app_version = AppVersion::global(cx);
 788        cx.spawn(async move |this, cx| {
 789            let experiments = cx
 790                .background_spawn(async move {
 791                    let http_client = client.http_client();
 792                    let token = llm_token.acquire(&client).await?;
 793                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
 794                    let request = http_client::Request::builder()
 795                        .method(Method::GET)
 796                        .uri(url.as_ref())
 797                        .header("Authorization", format!("Bearer {}", token))
 798                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 799                        .body(Default::default())?;
 800                    let mut response = http_client.send(request).await?;
 801                    if response.status().is_success() {
 802                        let mut body = Vec::new();
 803                        response.body_mut().read_to_end(&mut body).await?;
 804                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
 805                        Ok(experiments)
 806                    } else {
 807                        let mut body = String::new();
 808                        response.body_mut().read_to_string(&mut body).await?;
 809                        anyhow::bail!(
 810                            "Failed to fetch experiments: {:?}\nBody: {}",
 811                            response.status(),
 812                            body
 813                        );
 814                    }
 815                })
 816                .await?;
 817            this.update(cx, |this, cx| {
 818                this.available_experiments = experiments;
 819                cx.notify();
 820            })?;
 821            anyhow::Ok(())
 822        })
 823        .detach_and_log_err(cx);
 824    }
 825
 826    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 827        use ui::IconName;
 828        match self.edit_prediction_model {
 829            EditPredictionModel::Sweep => {
 830                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 831                    .with_disabled(IconName::SweepAiDisabled)
 832                    .with_up(IconName::SweepAiUp)
 833                    .with_down(IconName::SweepAiDown)
 834                    .with_error(IconName::SweepAiError)
 835            }
 836            EditPredictionModel::Mercury => {
 837                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 838            }
 839            EditPredictionModel::Zeta => {
 840                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 841                    .with_disabled(IconName::ZedPredictDisabled)
 842                    .with_up(IconName::ZedPredictUp)
 843                    .with_down(IconName::ZedPredictDown)
 844                    .with_error(IconName::ZedPredictError)
 845            }
 846            EditPredictionModel::Fim { .. } => {
 847                let settings = &all_language_settings(None, cx).edit_predictions;
 848                match settings.provider {
 849                    EditPredictionProvider::Ollama => {
 850                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 851                    }
 852                    _ => {
 853                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 854                    }
 855                }
 856            }
 857        }
 858    }
 859
 860    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 861        self.sweep_ai.api_token.read(cx).has_key()
 862    }
 863
 864    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 865        self.mercury.api_token.read(cx).has_key()
 866    }
 867
 868    pub fn clear_history(&mut self) {
 869        for project_state in self.projects.values_mut() {
 870            project_state.events.clear();
 871            project_state.last_event.take();
 872        }
 873    }
 874
 875    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 876        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 877            project_state.events.clear();
 878            project_state.last_event.take();
 879        }
 880    }
 881
 882    pub fn edit_history_for_project(
 883        &self,
 884        project: &Entity<Project>,
 885        cx: &App,
 886    ) -> Vec<StoredEvent> {
 887        self.projects
 888            .get(&project.entity_id())
 889            .map(|project_state| project_state.events(cx))
 890            .unwrap_or_default()
 891    }
 892
 893    pub fn context_for_project<'a>(
 894        &'a self,
 895        project: &Entity<Project>,
 896        cx: &'a mut App,
 897    ) -> Vec<RelatedFile> {
 898        self.projects
 899            .get(&project.entity_id())
 900            .map(|project_state| {
 901                project_state.context.update(cx, |context, cx| {
 902                    context
 903                        .related_files_with_buffers(cx)
 904                        .map(|(mut related_file, buffer)| {
 905                            related_file.in_open_source_repo = buffer
 906                                .read(cx)
 907                                .file()
 908                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
 909                            related_file
 910                        })
 911                        .collect()
 912                })
 913            })
 914            .unwrap_or_default()
 915    }
 916
 917    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 918        self.projects
 919            .get(&project.entity_id())
 920            .and_then(|project| project.copilot.clone())
 921    }
 922
 923    pub fn start_copilot_for_project(
 924        &mut self,
 925        project: &Entity<Project>,
 926        cx: &mut Context<Self>,
 927    ) -> Option<Entity<Copilot>> {
 928        if DisableAiSettings::get(None, cx).disable_ai {
 929            return None;
 930        }
 931        let state = self.get_or_init_project(project, cx);
 932
 933        if state.copilot.is_some() {
 934            return state.copilot.clone();
 935        }
 936        let _project = project.clone();
 937        let project = project.read(cx);
 938
 939        let node = project.node_runtime().cloned();
 940        if let Some(node) = node {
 941            let next_id = project.languages().next_language_server_id();
 942            let fs = project.fs().clone();
 943
 944            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 945            state.copilot = Some(copilot.clone());
 946            Some(copilot)
 947        } else {
 948            None
 949        }
 950    }
 951
 952    pub fn context_for_project_with_buffers<'a>(
 953        &'a self,
 954        project: &Entity<Project>,
 955        cx: &'a mut App,
 956    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 957        self.projects
 958            .get(&project.entity_id())
 959            .map(|project| {
 960                project.context.update(cx, |context, cx| {
 961                    context.related_files_with_buffers(cx).collect()
 962                })
 963            })
 964            .unwrap_or_default()
 965    }
 966
 967    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 968        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
 969            self.user_store.read(cx).edit_prediction_usage()
 970        } else {
 971            None
 972        }
 973    }
 974
 975    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 976        self.get_or_init_project(project, cx);
 977    }
 978
 979    pub fn register_buffer(
 980        &mut self,
 981        buffer: &Entity<Buffer>,
 982        project: &Entity<Project>,
 983        cx: &mut Context<Self>,
 984    ) {
 985        let project_state = self.get_or_init_project(project, cx);
 986        Self::register_buffer_impl(project_state, buffer, project, cx);
 987    }
 988
    /// Returns the [`ProjectState`] for `project`, creating it on first use.
    ///
    /// When the state is created, this also:
    /// - builds a [`RelatedExcerptStore`] for the project and forwards its events to
    ///   `handle_excerpt_store_event`, tagged with the project's entity id;
    /// - subscribes to the project's events via `handle_project_event`;
    /// - removes the project's entry from `self.projects` (and notifies) when the project
    ///   entity is released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            // The closure below only runs when no state exists yet for this project.
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Detached rather than kept in `_subscriptions`.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop this project's state once the project entity goes away.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1029
1030    pub fn remove_project(&mut self, project: &Entity<Project>) {
1031        self.projects.remove(&project.entity_id());
1032    }
1033
1034    fn handle_excerpt_store_event(
1035        &mut self,
1036        project_entity_id: EntityId,
1037        event: &RelatedExcerptStoreEvent,
1038    ) {
1039        if let Some(project_state) = self.projects.get(&project_entity_id) {
1040            if let Some(debug_tx) = project_state.debug_tx.clone() {
1041                match event {
1042                    RelatedExcerptStoreEvent::StartedRefresh => {
1043                        debug_tx
1044                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
1045                                ContextRetrievalStartedDebugEvent {
1046                                    project_entity_id: project_entity_id,
1047                                    timestamp: Instant::now(),
1048                                    search_prompt: String::new(),
1049                                },
1050                            ))
1051                            .ok();
1052                    }
1053                    RelatedExcerptStoreEvent::FinishedRefresh {
1054                        cache_hit_count,
1055                        cache_miss_count,
1056                        mean_definition_latency,
1057                        max_definition_latency,
1058                    } => {
1059                        debug_tx
1060                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
1061                                ContextRetrievalFinishedDebugEvent {
1062                                    project_entity_id: project_entity_id,
1063                                    timestamp: Instant::now(),
1064                                    metadata: vec![
1065                                        (
1066                                            "Cache Hits",
1067                                            format!(
1068                                                "{}/{}",
1069                                                cache_hit_count,
1070                                                cache_hit_count + cache_miss_count
1071                                            )
1072                                            .into(),
1073                                        ),
1074                                        (
1075                                            "Max LSP Time",
1076                                            format!("{} ms", max_definition_latency.as_millis())
1077                                                .into(),
1078                                        ),
1079                                        (
1080                                            "Mean LSP Time",
1081                                            format!("{} ms", mean_definition_latency.as_millis())
1082                                                .into(),
1083                                        ),
1084                                    ],
1085                                },
1086                            ))
1087                            .ok();
1088                    }
1089                }
1090            }
1091        }
1092    }
1093
1094    pub fn debug_info(
1095        &mut self,
1096        project: &Entity<Project>,
1097        cx: &mut Context<Self>,
1098    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1099        let project_state = self.get_or_init_project(project, cx);
1100        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1101        project_state.debug_tx = Some(debug_watch_tx);
1102        debug_watch_rx
1103    }
1104
1105    fn handle_project_event(
1106        &mut self,
1107        project: Entity<Project>,
1108        event: &project::Event,
1109        cx: &mut Context<Self>,
1110    ) {
1111        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1112            return;
1113        }
1114        // TODO [zeta2] init with recent paths
1115        match event {
1116            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1117                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1118                    return;
1119                };
1120                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1121                if let Some(path) = path {
1122                    if let Some(ix) = project_state
1123                        .recent_paths
1124                        .iter()
1125                        .position(|probe| probe == &path)
1126                    {
1127                        project_state.recent_paths.remove(ix);
1128                    }
1129                    project_state.recent_paths.push_front(path);
1130                }
1131            }
1132            project::Event::DiagnosticsUpdated { .. } => {
1133                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1134                    self.refresh_prediction_from_diagnostics(
1135                        project,
1136                        DiagnosticSearchScope::Global,
1137                        cx,
1138                    );
1139                }
1140            }
1141            _ => (),
1142        }
1143    }
1144
    /// Ensures `buffer` is tracked in `project_state` and returns its [`RegisteredBuffer`].
    ///
    /// If the buffer has a file belonging to one of the project's worktrees, a
    /// [`LicenseDetectionWatcher`] is created for that worktree the first time it is seen. The
    /// watcher is removed from the project's state when the worktree is released.
    ///
    /// On first registration of the buffer, its current text snapshot and file are recorded,
    /// and two subscriptions are stored with it:
    /// - `BufferEvent::Edited` reports the change through `report_changes_for_buffer` with
    ///   `is_predicted = false`, as long as the project is still alive;
    /// - releasing the buffer removes its entry from `registered_buffers`.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                // One watcher per worktree; only created if none exists yet.
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        // Held weakly so this subscription does not keep the project alive.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1211
    /// Records the edits made to `buffer` since the snapshot last stored for it.
    ///
    /// `is_predicted` is `true` when the change comes from accepting an edit prediction
    /// (`accept_current_prediction`) and `false` for ordinary buffer edits.
    ///
    /// In order, this:
    /// 1. registers the buffer if needed. It returns early if the buffer version is unchanged or
    ///    there are no edits since the stored snapshot.
    /// 2. summarizes those edits: edit count, total deleted/inserted lengths, and one anchor
    ///    range from the start of the first edit to the end of the last. It then bumps
    ///    `last_edit_at` on pending settled predictions whose editable range overlaps that range.
    /// 3. records a [`UserActionRecord`] classified from those totals and positioned at the end
    ///    of the last edit.
    /// 4. coalesces the change into the project's `last_event` when it is the next snapshot of
    ///    the same buffer, has the same `is_predicted`, and lies within
    ///    `CHANGE_GROUPING_LINE_SPAN` lines of the previous edit. Otherwise it finalizes the
    ///    previous `last_event` into `events` and starts a new one.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        // Swap in the new file/snapshot, keeping the previous ones to diff against.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        // No edits since the stored version: nothing to record.
        let Some(edit_range) = edit_range else {
            return;
        };

        // Edits touching a pending prediction's editable range restart its quiet period.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Classify the change. A total inserted/deleted length equal to the number of edits
        // (e.g. one character per cursor) counts as a char action; anything else as a
        // selection action.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        if let Some(last_event) = project_state.last_event.as_mut() {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // After a pause of at least LAST_CHANGE_GROUPING_TIME, remember the snapshot
                // as it was just before this edit.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        if let Some(event) = project_state.last_event.take() {
            // `finalize` may decline to produce a stored event (returns `None`).
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Keeps at most EVENT_COUNT_MAX - 1 finalized events, dropping the oldest;
                // presumably the remaining slot accounts for `last_event` — TODO confirm.
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1345
1346    fn prediction_at(
1347        &mut self,
1348        buffer: &Entity<Buffer>,
1349        position: Option<language::Anchor>,
1350        project: &Entity<Project>,
1351        cx: &App,
1352    ) -> Option<BufferEditPrediction<'_>> {
1353        let project_state = self.projects.get_mut(&project.entity_id())?;
1354        if let Some(position) = position
1355            && let Some(buffer) = project_state
1356                .registered_buffers
1357                .get_mut(&buffer.entity_id())
1358        {
1359            buffer.last_position = Some(position);
1360        }
1361
1362        let CurrentEditPrediction {
1363            requested_by,
1364            prediction,
1365            ..
1366        } = project_state.current_prediction.as_ref()?;
1367
1368        if prediction.targets_buffer(buffer.read(cx)) {
1369            Some(BufferEditPrediction::Local { prediction })
1370        } else {
1371            let show_jump = match requested_by {
1372                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1373                    requested_by_buffer_id == &buffer.entity_id()
1374                }
1375                PredictionRequestedBy::DiagnosticsUpdate => true,
1376            };
1377
1378            if show_jump {
1379                Some(BufferEditPrediction::Jump { prediction })
1380            } else {
1381                None
1382            }
1383        }
1384    }
1385
1386    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1387        let Some(current_prediction) = self
1388            .projects
1389            .get_mut(&project.entity_id())
1390            .and_then(|project_state| project_state.current_prediction.take())
1391        else {
1392            return;
1393        };
1394
1395        self.report_changes_for_buffer(&current_prediction.prediction.buffer, project, true, cx);
1396
1397        // can't hold &mut project_state ref across report_changes_for_buffer_call
1398        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1399            return;
1400        };
1401
1402        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1403            project_state.cancel_pending_prediction(pending_prediction, cx);
1404        }
1405
1406        match self.edit_prediction_model {
1407            EditPredictionModel::Sweep => {
1408                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1409            }
1410            EditPredictionModel::Mercury => {
1411                mercury::edit_prediction_accepted(
1412                    current_prediction.prediction.id,
1413                    self.client.http_client(),
1414                    cx,
1415                );
1416            }
1417            EditPredictionModel::Zeta => {
1418                let is_cloud = !matches!(
1419                    all_language_settings(None, cx).edit_predictions.provider,
1420                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1421                );
1422                if is_cloud {
1423                    zeta::edit_prediction_accepted(self, current_prediction, cx)
1424                }
1425            }
1426            EditPredictionModel::Fim { .. } => {}
1427        }
1428    }
1429
    /// Background loop that batches rejected edit predictions received on `rx` and reports
    /// them to the `/predict_edits/reject` endpoint.
    ///
    /// While the batch holds fewer than half of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`
    /// entries, each rejection is debounced. If another rejection arrives before
    /// `REJECT_REQUEST_DEBOUNCE` elapses, it joins the batch; otherwise (timer fired or the
    /// channel closed) the batch is flushed.
    ///
    /// Each request sends at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the most
    /// recent rejections. They are removed from the batch only when the request succeeds, so
    /// failed entries are retried on a later flush. A later flush only happens after another
    /// rejection is received. The loop ends when `rx` is closed.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            // Debounce small batches: keep collecting while rejections keep arriving.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            // NOTE(review): panics if the LLM URL cannot be built; presumably treated as an
            // invariant — confirm.
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
                true,
            )
            .await;

            // Only forget the sent rejections once the server has accepted them.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1488
1489    async fn run_settled_predictions_worker(
1490        this: WeakEntity<Self>,
1491        mut rx: UnboundedReceiver<Instant>,
1492        cx: &mut AsyncApp,
1493    ) {
1494        let mut next_wake_time: Option<Instant> = None;
1495        loop {
1496            let now = cx.background_executor().now();
1497            if let Some(wake_time) = next_wake_time.take() {
1498                cx.background_executor()
1499                    .timer(wake_time.duration_since(now))
1500                    .await;
1501            } else {
1502                let Some(new_enqueue_time) = rx.next().await else {
1503                    break;
1504                };
1505                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1506                while rx.next().now_or_never().flatten().is_some() {}
1507                continue;
1508            }
1509
1510            let Some(this) = this.upgrade() else {
1511                break;
1512            };
1513
1514            let now = cx.background_executor().now();
1515
1516            let mut oldest_edited_at = None;
1517
1518            this.update(cx, |this, _| {
1519                for (_, project_state) in this.projects.iter_mut() {
1520                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
1521                        registered_buffer
1522                            .pending_predictions
1523                            .retain_mut(|pending_prediction| {
1524                                let age =
1525                                    now.saturating_duration_since(pending_prediction.enqueued_at);
1526                                if age >= EDIT_PREDICTION_SETTLED_TTL {
1527                                    return false;
1528                                }
1529
1530                                let quiet_for =
1531                                    now.saturating_duration_since(pending_prediction.last_edit_at);
1532                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
1533                                    let settled_editable_region = registered_buffer
1534                                        .snapshot
1535                                        .text_for_range(
1536                                            pending_prediction.editable_anchor_range.clone(),
1537                                        )
1538                                        .collect::<String>();
1539
1540                                    #[cfg(test)]
1541                                    if let Some(callback) = &this.settled_event_callback {
1542                                        callback(
1543                                            pending_prediction.request_id.clone(),
1544                                            settled_editable_region.clone(),
1545                                        );
1546                                    }
1547
1548                                    telemetry::event!(
1549                                        EDIT_PREDICTION_SETTLED_EVENT,
1550                                        request_id = pending_prediction.request_id.0.clone(),
1551                                        settled_editable_region,
1552                                    );
1553
1554                                    return false;
1555                                }
1556
1557                                if oldest_edited_at
1558                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
1559                                {
1560                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
1561                                }
1562
1563                                true
1564                            });
1565                    }
1566                }
1567            });
1568
1569            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1570        }
1571    }
1572
1573    pub(crate) fn enqueue_settled_prediction(
1574        &mut self,
1575        request_id: EditPredictionId,
1576        project: &Entity<Project>,
1577        edited_buffer: &Entity<Buffer>,
1578        edited_buffer_snapshot: &BufferSnapshot,
1579        editable_offset_range: Range<usize>,
1580        cx: &mut Context<Self>,
1581    ) {
1582        let project_state = self.get_or_init_project(project, cx);
1583        if let Some(buffer) = project_state
1584            .registered_buffers
1585            .get_mut(&edited_buffer.entity_id())
1586        {
1587            let now = cx.background_executor().now();
1588            buffer.pending_predictions.push(PendingSettledPrediction {
1589                request_id,
1590                editable_anchor_range: edited_buffer_snapshot
1591                    .anchor_range_around(editable_offset_range),
1592                enqueued_at: now,
1593                last_edit_at: now,
1594            });
1595            self.settled_predictions_tx.unbounded_send(now).ok();
1596        }
1597    }
1598
1599    fn reject_current_prediction(
1600        &mut self,
1601        reason: EditPredictionRejectReason,
1602        project: &Entity<Project>,
1603        cx: &App,
1604    ) {
1605        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1606            project_state.pending_predictions.clear();
1607            if let Some(prediction) = project_state.current_prediction.take() {
1608                let model_version = prediction.prediction.model_version.clone();
1609                self.reject_prediction(
1610                    prediction.prediction.id,
1611                    reason,
1612                    prediction.was_shown,
1613                    model_version,
1614                    cx,
1615                );
1616            }
1617        };
1618    }
1619
1620    fn did_show_current_prediction(
1621        &mut self,
1622        project: &Entity<Project>,
1623        display_type: edit_prediction_types::SuggestionDisplayType,
1624        cx: &mut Context<Self>,
1625    ) {
1626        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1627            return;
1628        };
1629
1630        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1631            return;
1632        };
1633
1634        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1635        let previous_shown_with = current_prediction.shown_with;
1636
1637        if previous_shown_with.is_none() || !is_jump {
1638            current_prediction.shown_with = Some(display_type);
1639        }
1640
1641        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1642
1643        if is_first_non_jump_show {
1644            current_prediction.was_shown = true;
1645        }
1646
1647        let display_type_changed = previous_shown_with != Some(display_type);
1648
1649        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1650            sweep_ai::edit_prediction_shown(
1651                &self.sweep_ai,
1652                self.client.clone(),
1653                &current_prediction.prediction,
1654                display_type,
1655                cx,
1656            );
1657        }
1658
1659        if is_first_non_jump_show {
1660            self.shown_predictions
1661                .push_front(current_prediction.prediction.clone());
1662            if self.shown_predictions.len() > 50 {
1663                let completion = self.shown_predictions.pop_back().unwrap();
1664                self.rated_predictions.remove(&completion.id);
1665            }
1666        }
1667    }
1668
1669    fn reject_prediction(
1670        &mut self,
1671        prediction_id: EditPredictionId,
1672        reason: EditPredictionRejectReason,
1673        was_shown: bool,
1674        model_version: Option<String>,
1675        cx: &App,
1676    ) {
1677        match self.edit_prediction_model {
1678            EditPredictionModel::Zeta => {
1679                let is_cloud = !matches!(
1680                    all_language_settings(None, cx).edit_predictions.provider,
1681                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1682                );
1683                if is_cloud {
1684                    self.reject_predictions_tx
1685                        .unbounded_send(EditPredictionRejection {
1686                            request_id: prediction_id.to_string(),
1687                            reason,
1688                            was_shown,
1689                            model_version,
1690                        })
1691                        .log_err();
1692                }
1693            }
1694            EditPredictionModel::Mercury => {
1695                mercury::edit_prediction_rejected(
1696                    prediction_id,
1697                    was_shown,
1698                    reason,
1699                    self.client.http_client(),
1700                    cx,
1701                );
1702            }
1703            EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1704        }
1705    }
1706
1707    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1708        self.projects
1709            .get(&project.entity_id())
1710            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1711    }
1712
1713    pub fn refresh_prediction_from_buffer(
1714        &mut self,
1715        project: Entity<Project>,
1716        buffer: Entity<Buffer>,
1717        position: language::Anchor,
1718        cx: &mut Context<Self>,
1719    ) {
1720        self.queue_prediction_refresh(
1721            project.clone(),
1722            PredictEditsRequestTrigger::Other,
1723            buffer.entity_id(),
1724            cx,
1725            move |this, cx| {
1726                let Some(request_task) = this
1727                    .update(cx, |this, cx| {
1728                        this.request_prediction(
1729                            &project,
1730                            &buffer,
1731                            position,
1732                            PredictEditsRequestTrigger::Other,
1733                            cx,
1734                        )
1735                    })
1736                    .log_err()
1737                else {
1738                    return Task::ready(anyhow::Ok(None));
1739                };
1740
1741                cx.spawn(async move |_cx| {
1742                    request_task.await.map(|prediction_result| {
1743                        prediction_result.map(|prediction_result| {
1744                            (
1745                                prediction_result,
1746                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1747                            )
1748                        })
1749                    })
1750                })
1751            },
1752        )
1753    }
1754
1755    pub fn refresh_prediction_from_diagnostics(
1756        &mut self,
1757        project: Entity<Project>,
1758        scope: DiagnosticSearchScope,
1759        cx: &mut Context<Self>,
1760    ) {
1761        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1762            return;
1763        }
1764
1765        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1766            return;
1767        };
1768
1769        // Prefer predictions from buffer
1770        if project_state.current_prediction.is_some() {
1771            return;
1772        }
1773
1774        self.queue_prediction_refresh(
1775            project.clone(),
1776            PredictEditsRequestTrigger::Diagnostics,
1777            project.entity_id(),
1778            cx,
1779            move |this, cx| {
1780                let Some((active_buffer, snapshot, cursor_point)) = this
1781                    .read_with(cx, |this, cx| {
1782                        let project_state = this.projects.get(&project.entity_id())?;
1783                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
1784                        let snapshot = buffer.read(cx).snapshot();
1785
1786                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
1787                            return None;
1788                        }
1789
1790                        let cursor_point = position
1791                            .map(|pos| pos.to_point(&snapshot))
1792                            .unwrap_or_default();
1793
1794                        Some((buffer, snapshot, cursor_point))
1795                    })
1796                    .log_err()
1797                    .flatten()
1798                else {
1799                    return Task::ready(anyhow::Ok(None));
1800                };
1801
1802                cx.spawn(async move |cx| {
1803                    let diagnostic_search_range = match scope {
1804                        DiagnosticSearchScope::Local => {
1805                            let diagnostic_search_start =
1806                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1807                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1808                            Point::new(diagnostic_search_start, 0)
1809                                ..Point::new(diagnostic_search_end, 0)
1810                        }
1811                        DiagnosticSearchScope::Global => Default::default(),
1812                    };
1813
1814                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1815                        active_buffer,
1816                        &snapshot,
1817                        diagnostic_search_range,
1818                        cursor_point,
1819                        &project,
1820                        cx,
1821                    )
1822                    .await?
1823                    else {
1824                        return anyhow::Ok(None);
1825                    };
1826
1827                    let Some(prediction_result) = this
1828                        .update(cx, |this, cx| {
1829                            this.request_prediction(
1830                                &project,
1831                                &jump_buffer,
1832                                jump_position,
1833                                PredictEditsRequestTrigger::Diagnostics,
1834                                cx,
1835                            )
1836                        })?
1837                        .await?
1838                    else {
1839                        return anyhow::Ok(None);
1840                    };
1841
1842                    this.update(cx, |this, cx| {
1843                        Some((
1844                            if this
1845                                .get_or_init_project(&project, cx)
1846                                .current_prediction
1847                                .is_none()
1848                            {
1849                                prediction_result
1850                            } else {
1851                                EditPredictionResult {
1852                                    id: prediction_result.id,
1853                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1854                                }
1855                            },
1856                            PredictionRequestedBy::DiagnosticsUpdate,
1857                        ))
1858                    })
1859                })
1860            },
1861        );
1862    }
1863
1864    fn predictions_enabled_at(
1865        snapshot: &BufferSnapshot,
1866        position: Option<language::Anchor>,
1867        cx: &App,
1868    ) -> bool {
1869        let file = snapshot.file();
1870        let all_settings = all_language_settings(file, cx);
1871        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1872            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1873        {
1874            return false;
1875        }
1876
1877        if let Some(last_position) = position {
1878            let settings = snapshot.settings_at(last_position, cx);
1879
1880            if !settings.edit_predictions_disabled_in.is_empty()
1881                && let Some(scope) = snapshot.language_scope_at(last_position)
1882                && let Some(scope_name) = scope.override_name()
1883                && settings
1884                    .edit_predictions_disabled_in
1885                    .iter()
1886                    .any(|s| s == scope_name)
1887            {
1888                return false;
1889            }
1890        }
1891
1892        true
1893    }
1894
    /// Throttle window used by `queue_prediction_refresh`: when the previous refresh of the
    /// same trigger kind targeted the same throttle entity, a new request waits until this much
    /// time has passed since that refresh.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1896}
1897
1898fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1899    match provider {
1900        EditPredictionProvider::Zed
1901        | EditPredictionProvider::Sweep
1902        | EditPredictionProvider::Mercury
1903        | EditPredictionProvider::Ollama
1904        | EditPredictionProvider::OpenAiCompatibleApi
1905        | EditPredictionProvider::Experimental(_) => true,
1906        EditPredictionProvider::None
1907        | EditPredictionProvider::Copilot
1908        | EditPredictionProvider::Codestral => false,
1909    }
1910}
1911
1912impl EditPredictionStore {
1913    fn queue_prediction_refresh(
1914        &mut self,
1915        project: Entity<Project>,
1916        request_trigger: PredictEditsRequestTrigger,
1917        throttle_entity: EntityId,
1918        cx: &mut Context<Self>,
1919        do_refresh: impl FnOnce(
1920            WeakEntity<Self>,
1921            &mut AsyncApp,
1922        )
1923            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1924        + 'static,
1925    ) {
1926        fn select_throttle(
1927            project_state: &mut ProjectState,
1928            request_trigger: PredictEditsRequestTrigger,
1929        ) -> &mut Option<(EntityId, Instant)> {
1930            match request_trigger {
1931                PredictEditsRequestTrigger::Diagnostics => {
1932                    &mut project_state.last_jump_prediction_refresh
1933                }
1934                _ => &mut project_state.last_edit_prediction_refresh,
1935            }
1936        }
1937
1938        let (needs_acceptance_tracking, max_pending_predictions) =
1939            match all_language_settings(None, cx).edit_predictions.provider {
1940                EditPredictionProvider::Zed
1941                | EditPredictionProvider::Sweep
1942                | EditPredictionProvider::Mercury
1943                | EditPredictionProvider::Experimental(_) => (true, 2),
1944                EditPredictionProvider::Ollama => (false, 1),
1945                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
1946                EditPredictionProvider::None
1947                | EditPredictionProvider::Copilot
1948                | EditPredictionProvider::Codestral => {
1949                    log::error!("queue_prediction_refresh called with non-store provider");
1950                    return;
1951                }
1952            };
1953
1954        let drop_on_cancel = !needs_acceptance_tracking;
1955        let throttle_timeout = Self::THROTTLE_TIMEOUT;
1956        let project_state = self.get_or_init_project(&project, cx);
1957        let pending_prediction_id = project_state.next_pending_prediction_id;
1958        project_state.next_pending_prediction_id += 1;
1959        let last_request = *select_throttle(project_state, request_trigger);
1960
1961        let task = cx.spawn(async move |this, cx| {
1962            if let Some(timeout) = last_request.and_then(|(last_entity, last_timestamp)| {
1963                if throttle_entity != last_entity {
1964                    return None;
1965                }
1966                (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
1967            }) {
1968                cx.background_executor().timer(timeout).await;
1969            }
1970
1971            // If this task was cancelled before the throttle timeout expired,
1972            // do not perform a request.
1973            let mut is_cancelled = true;
1974            this.update(cx, |this, cx| {
1975                let project_state = this.get_or_init_project(&project, cx);
1976                let was_cancelled = project_state
1977                    .cancelled_predictions
1978                    .remove(&pending_prediction_id);
1979                if !was_cancelled {
1980                    let new_refresh = (throttle_entity, Instant::now());
1981                    *select_throttle(project_state, request_trigger) = Some(new_refresh);
1982                    is_cancelled = false;
1983                }
1984            })
1985            .ok();
1986            if is_cancelled {
1987                return None;
1988            }
1989
1990            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1991            let new_prediction_id = new_prediction_result
1992                .as_ref()
1993                .map(|(prediction, _)| prediction.id.clone());
1994
1995            // When a prediction completes, remove it from the pending list, and cancel
1996            // any pending predictions that were enqueued before it.
1997            this.update(cx, |this, cx| {
1998                let project_state = this.get_or_init_project(&project, cx);
1999
2000                let is_cancelled = project_state
2001                    .cancelled_predictions
2002                    .remove(&pending_prediction_id);
2003
2004                let new_current_prediction = if !is_cancelled
2005                    && let Some((prediction_result, requested_by)) = new_prediction_result
2006                {
2007                    match prediction_result.prediction {
2008                        Ok(prediction) => {
2009                            let new_prediction = CurrentEditPrediction {
2010                                requested_by,
2011                                prediction,
2012                                was_shown: false,
2013                                shown_with: None,
2014                            };
2015
2016                            if let Some(current_prediction) =
2017                                project_state.current_prediction.as_ref()
2018                            {
2019                                if new_prediction.should_replace_prediction(&current_prediction, cx)
2020                                {
2021                                    this.reject_current_prediction(
2022                                        EditPredictionRejectReason::Replaced,
2023                                        &project,
2024                                        cx,
2025                                    );
2026
2027                                    Some(new_prediction)
2028                                } else {
2029                                    this.reject_prediction(
2030                                        new_prediction.prediction.id,
2031                                        EditPredictionRejectReason::CurrentPreferred,
2032                                        false,
2033                                        new_prediction.prediction.model_version,
2034                                        cx,
2035                                    );
2036                                    None
2037                                }
2038                            } else {
2039                                Some(new_prediction)
2040                            }
2041                        }
2042                        Err(reject_reason) => {
2043                            this.reject_prediction(
2044                                prediction_result.id,
2045                                reject_reason,
2046                                false,
2047                                None,
2048                                cx,
2049                            );
2050                            None
2051                        }
2052                    }
2053                } else {
2054                    None
2055                };
2056
2057                let project_state = this.get_or_init_project(&project, cx);
2058
2059                if let Some(new_prediction) = new_current_prediction {
2060                    project_state.current_prediction = Some(new_prediction);
2061                }
2062
2063                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2064                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2065                    if pending_prediction.id == pending_prediction_id {
2066                        pending_predictions.remove(ix);
2067                        for pending_prediction in pending_predictions.drain(0..ix) {
2068                            project_state.cancel_pending_prediction(pending_prediction, cx)
2069                        }
2070                        break;
2071                    }
2072                }
2073                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2074                cx.notify();
2075            })
2076            .ok();
2077
2078            new_prediction_id
2079        });
2080
2081        if project_state.pending_predictions.len() < max_pending_predictions {
2082            project_state.pending_predictions.push(PendingPrediction {
2083                id: pending_prediction_id,
2084                task,
2085                drop_on_cancel,
2086            });
2087        } else {
2088            let pending_prediction = project_state.pending_predictions.pop().unwrap();
2089            project_state.pending_predictions.push(PendingPrediction {
2090                id: pending_prediction_id,
2091                task,
2092                drop_on_cancel,
2093            });
2094            project_state.cancel_pending_prediction(pending_prediction, cx);
2095        }
2096    }
2097
2098    pub fn request_prediction(
2099        &mut self,
2100        project: &Entity<Project>,
2101        active_buffer: &Entity<Buffer>,
2102        position: language::Anchor,
2103        trigger: PredictEditsRequestTrigger,
2104        cx: &mut Context<Self>,
2105    ) -> Task<Result<Option<EditPredictionResult>>> {
2106        self.request_prediction_internal(
2107            project.clone(),
2108            active_buffer.clone(),
2109            position,
2110            trigger,
2111            cx.has_flag::<Zeta2FeatureFlag>(),
2112            cx,
2113        )
2114    }
2115
2116    fn request_prediction_internal(
2117        &mut self,
2118        project: Entity<Project>,
2119        active_buffer: Entity<Buffer>,
2120        position: language::Anchor,
2121        trigger: PredictEditsRequestTrigger,
2122        allow_jump: bool,
2123        cx: &mut Context<Self>,
2124    ) -> Task<Result<Option<EditPredictionResult>>> {
2125        self.get_or_init_project(&project, cx);
2126        let project_state = self.projects.get(&project.entity_id()).unwrap();
2127        let stored_events = project_state.events(cx);
2128        let has_events = !stored_events.is_empty();
2129        let events: Vec<Arc<zeta_prompt::Event>> =
2130            stored_events.iter().map(|e| e.event.clone()).collect();
2131        let debug_tx = project_state.debug_tx.clone();
2132
2133        let snapshot = active_buffer.read(cx).snapshot();
2134        let cursor_point = position.to_point(&snapshot);
2135        let current_offset = position.to_offset(&snapshot);
2136
2137        let mut user_actions: Vec<UserActionRecord> =
2138            project_state.user_actions.iter().cloned().collect();
2139
2140        if let Some(last_action) = user_actions.last() {
2141            if last_action.buffer_id == active_buffer.entity_id()
2142                && current_offset != last_action.offset
2143            {
2144                let timestamp_epoch_ms = SystemTime::now()
2145                    .duration_since(UNIX_EPOCH)
2146                    .map(|d| d.as_millis() as u64)
2147                    .unwrap_or(0);
2148                user_actions.push(UserActionRecord {
2149                    action_type: UserActionType::CursorMovement,
2150                    buffer_id: active_buffer.entity_id(),
2151                    line_number: cursor_point.row,
2152                    offset: current_offset,
2153                    timestamp_epoch_ms,
2154                });
2155            }
2156        }
2157        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2158        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2159        let diagnostic_search_range =
2160            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2161
2162        let related_files = self.context_for_project(&project, cx);
2163
2164        let is_open_source = snapshot
2165            .file()
2166            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2167            && events.iter().all(|event| event.in_open_source_repo())
2168            && related_files.iter().all(|file| file.in_open_source_repo);
2169
2170        let can_collect_data = !cfg!(test)
2171            && is_open_source
2172            && self.is_data_collection_enabled(cx)
2173            && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2174
2175        let inputs = EditPredictionModelInput {
2176            project: project.clone(),
2177            buffer: active_buffer.clone(),
2178            snapshot: snapshot,
2179            position,
2180            events,
2181            related_files,
2182            recent_paths: project_state.recent_paths.clone(),
2183            trigger,
2184            diagnostic_search_range: diagnostic_search_range,
2185            debug_tx,
2186            user_actions,
2187            can_collect_data,
2188            is_open_source,
2189        };
2190
2191        if can_collect_data && rand::random_ratio(1, 1000) {
2192            if let Some(task) = capture_example(
2193                project.clone(),
2194                active_buffer,
2195                position,
2196                stored_events,
2197                false,
2198                cx,
2199            ) {
2200                task.detach();
2201            }
2202        }
2203
2204        let task = match self.edit_prediction_model {
2205            EditPredictionModel::Zeta => zeta::request_prediction_with_zeta(self, inputs, cx),
2206            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2207            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2208            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2209        };
2210
2211        cx.spawn(async move |this, cx| {
2212            let prediction = task.await?;
2213
2214            if prediction.is_none() && allow_jump && has_events {
2215                this.update(cx, |this, cx| {
2216                    this.refresh_prediction_from_diagnostics(
2217                        project,
2218                        DiagnosticSearchScope::Local,
2219                        cx,
2220                    );
2221                })?;
2222                return anyhow::Ok(None);
2223            }
2224
2225            Ok(prediction)
2226        })
2227    }
2228
2229    pub(crate) async fn next_diagnostic_location(
2230        active_buffer: Entity<Buffer>,
2231        active_buffer_snapshot: &BufferSnapshot,
2232        active_buffer_diagnostic_search_range: Range<Point>,
2233        active_buffer_cursor_point: Point,
2234        project: &Entity<Project>,
2235        cx: &mut AsyncApp,
2236    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2237        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2238            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2239            .flat_map(|(_, _, _, selections)| {
2240                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2241            })
2242            .collect();
2243
2244        let mut jump_location = active_buffer_snapshot
2245            .diagnostic_groups(None)
2246            .into_iter()
2247            .filter_map(|(_, group)| {
2248                let range = &group.entries[group.primary_ix]
2249                    .range
2250                    .to_point(&active_buffer_snapshot);
2251                if range.overlaps(&active_buffer_diagnostic_search_range) {
2252                    return None;
2253                }
2254                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2255                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2256                });
2257                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2258                    <= DIAGNOSTIC_LINES_RANGE;
2259                if near_collaborator && !near_local {
2260                    return None;
2261                }
2262                Some(range.start)
2263            })
2264            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2265            .map(|position| {
2266                (
2267                    active_buffer.clone(),
2268                    active_buffer_snapshot.anchor_before(position),
2269                )
2270            });
2271
2272        if jump_location.is_none() {
2273            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2274                let file = buffer.file()?;
2275
2276                Some(ProjectPath {
2277                    worktree_id: file.worktree_id(cx),
2278                    path: file.path().clone(),
2279                })
2280            });
2281
2282            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2283                project
2284                    .diagnostic_summaries(false, cx)
2285                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2286                    .map(|(path, _, _)| {
2287                        let shared_prefix = path
2288                            .path
2289                            .components()
2290                            .zip(
2291                                active_buffer_path
2292                                    .as_ref()
2293                                    .map(|p| p.path.components())
2294                                    .unwrap_or_default(),
2295                            )
2296                            .take_while(|(a, b)| a == b)
2297                            .count();
2298                        (path, shared_prefix)
2299                    })
2300                    .collect()
2301            });
2302
2303            candidates.sort_by(|a, b| b.1.cmp(&a.1));
2304
2305            for (path, _) in candidates {
2306                let candidate_buffer = project
2307                    .update(cx, |project, cx| project.open_buffer(path, cx))
2308                    .await?;
2309
2310                let (has_collaborators, diagnostic_position) =
2311                    candidate_buffer.read_with(cx, |buffer, _cx| {
2312                        let snapshot = buffer.snapshot();
2313                        let has_collaborators = snapshot
2314                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2315                            .next()
2316                            .is_some();
2317                        let position = buffer
2318                            .buffer_diagnostics(None)
2319                            .into_iter()
2320                            .min_by_key(|entry| entry.diagnostic.severity)
2321                            .map(|entry| entry.range.start);
2322                        (has_collaborators, position)
2323                    });
2324
2325                if has_collaborators {
2326                    continue;
2327                }
2328
2329                if let Some(position) = diagnostic_position {
2330                    jump_location = Some((candidate_buffer, position));
2331                    break;
2332                }
2333            }
2334        }
2335
2336        anyhow::Ok(jump_location)
2337    }
2338
2339    async fn send_raw_llm_request(
2340        request: RawCompletionRequest,
2341        client: Arc<Client>,
2342        custom_url: Option<Arc<Url>>,
2343        llm_token: LlmApiToken,
2344        app_version: Version,
2345    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2346        let url = if let Some(custom_url) = custom_url {
2347            custom_url.as_ref().clone()
2348        } else {
2349            client
2350                .http_client()
2351                .build_zed_llm_url("/predict_edits/raw", &[])?
2352        };
2353
2354        Self::send_api_request(
2355            |builder| {
2356                let req = builder
2357                    .uri(url.as_ref())
2358                    .body(serde_json::to_string(&request)?.into());
2359                Ok(req?)
2360            },
2361            client,
2362            llm_token,
2363            app_version,
2364            true,
2365        )
2366        .await
2367    }
2368
2369    pub(crate) async fn send_v3_request(
2370        input: ZetaPromptInput,
2371        client: Arc<Client>,
2372        llm_token: LlmApiToken,
2373        app_version: Version,
2374        trigger: PredictEditsRequestTrigger,
2375    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2376        let url = client
2377            .http_client()
2378            .build_zed_llm_url("/predict_edits/v3", &[])?;
2379
2380        let request = PredictEditsV3Request { input, trigger };
2381
2382        let json_bytes = serde_json::to_vec(&request)?;
2383        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2384
2385        Self::send_api_request(
2386            |builder| {
2387                let req = builder
2388                    .uri(url.as_ref())
2389                    .header("Content-Encoding", "zstd")
2390                    .body(compressed.clone().into());
2391                Ok(req?)
2392            },
2393            client,
2394            llm_token,
2395            app_version,
2396            true,
2397        )
2398        .await
2399    }
2400
2401    fn handle_api_response<T>(
2402        this: &WeakEntity<Self>,
2403        response: Result<(T, Option<EditPredictionUsage>)>,
2404        cx: &mut gpui::AsyncApp,
2405    ) -> Result<T> {
2406        match response {
2407            Ok((data, usage)) => {
2408                if let Some(usage) = usage {
2409                    this.update(cx, |this, cx| {
2410                        this.user_store.update(cx, |user_store, cx| {
2411                            user_store.update_edit_prediction_usage(usage, cx);
2412                        });
2413                    })
2414                    .ok();
2415                }
2416                Ok(data)
2417            }
2418            Err(err) => {
2419                if err.is::<ZedUpdateRequiredError>() {
2420                    cx.update(|cx| {
2421                        this.update(cx, |this, _cx| {
2422                            this.update_required = true;
2423                        })
2424                        .ok();
2425
2426                        let error_message: SharedString = err.to_string().into();
2427                        show_app_notification(
2428                            NotificationId::unique::<ZedUpdateRequiredError>(),
2429                            cx,
2430                            move |cx| {
2431                                cx.new(|cx| {
2432                                    ErrorMessagePrompt::new(error_message.clone(), cx)
2433                                        .with_link_button("Update Zed", "https://zed.dev/releases")
2434                                })
2435                            },
2436                        );
2437                    });
2438                }
2439                Err(err)
2440            }
2441        }
2442    }
2443
2444    async fn send_api_request<Res>(
2445        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2446        client: Arc<Client>,
2447        llm_token: LlmApiToken,
2448        app_version: Version,
2449        require_auth: bool,
2450    ) -> Result<(Res, Option<EditPredictionUsage>)>
2451    where
2452        Res: DeserializeOwned,
2453    {
2454        let http_client = client.http_client();
2455
2456        let mut token = if require_auth {
2457            Some(llm_token.acquire(&client).await?)
2458        } else {
2459            llm_token.acquire(&client).await.ok()
2460        };
2461        let mut did_retry = false;
2462
2463        loop {
2464            let request_builder = http_client::Request::builder().method(Method::POST);
2465
2466            let mut request_builder = request_builder
2467                .header("Content-Type", "application/json")
2468                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2469
2470            // Only add Authorization header if we have a token
2471            if let Some(ref token_value) = token {
2472                request_builder =
2473                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2474            }
2475
2476            let request = build(request_builder)?;
2477
2478            let mut response = http_client.send(request).await?;
2479
2480            if let Some(minimum_required_version) = response
2481                .headers()
2482                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2483                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2484            {
2485                anyhow::ensure!(
2486                    app_version >= minimum_required_version,
2487                    ZedUpdateRequiredError {
2488                        minimum_version: minimum_required_version
2489                    }
2490                );
2491            }
2492
2493            if response.status().is_success() {
2494                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2495
2496                let mut body = Vec::new();
2497                response.body_mut().read_to_end(&mut body).await?;
2498                return Ok((serde_json::from_slice(&body)?, usage));
2499            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2500                did_retry = true;
2501                token = Some(llm_token.refresh(&client).await?);
2502            } else {
2503                let mut body = String::new();
2504                response.body_mut().read_to_string(&mut body).await?;
2505                anyhow::bail!(
2506                    "Request failed with status: {:?}\nBody: {}",
2507                    response.status(),
2508                    body
2509                );
2510            }
2511        }
2512    }
2513
2514    pub fn refresh_context(
2515        &mut self,
2516        project: &Entity<Project>,
2517        buffer: &Entity<language::Buffer>,
2518        cursor_position: language::Anchor,
2519        cx: &mut Context<Self>,
2520    ) {
2521        self.get_or_init_project(project, cx)
2522            .context
2523            .update(cx, |store, cx| {
2524                store.refresh(buffer.clone(), cursor_position, cx);
2525            });
2526    }
2527
2528    #[cfg(feature = "cli-support")]
2529    pub fn set_context_for_buffer(
2530        &mut self,
2531        project: &Entity<Project>,
2532        related_files: Vec<RelatedFile>,
2533        cx: &mut Context<Self>,
2534    ) {
2535        self.get_or_init_project(project, cx)
2536            .context
2537            .update(cx, |store, cx| {
2538                store.set_related_files(related_files, cx);
2539            });
2540    }
2541
2542    #[cfg(feature = "cli-support")]
2543    pub fn set_recent_paths_for_project(
2544        &mut self,
2545        project: &Entity<Project>,
2546        paths: impl IntoIterator<Item = project::ProjectPath>,
2547        cx: &mut Context<Self>,
2548    ) {
2549        let project_state = self.get_or_init_project(project, cx);
2550        project_state.recent_paths = paths.into_iter().collect();
2551    }
2552
2553    fn is_file_open_source(
2554        &self,
2555        project: &Entity<Project>,
2556        file: &Arc<dyn File>,
2557        cx: &App,
2558    ) -> bool {
2559        if !file.is_local() || file.is_private() {
2560            return false;
2561        }
2562        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2563            return false;
2564        };
2565        project_state
2566            .license_detection_watchers
2567            .get(&file.worktree_id(cx))
2568            .as_ref()
2569            .is_some_and(|watcher| watcher.is_project_open_source())
2570    }
2571
2572    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2573        self.data_collection_choice.is_enabled(cx)
2574    }
2575
2576    fn load_data_collection_choice() -> DataCollectionChoice {
2577        let choice = KEY_VALUE_STORE
2578            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2579            .log_err()
2580            .flatten();
2581
2582        match choice.as_deref() {
2583            Some("true") => DataCollectionChoice::Enabled,
2584            Some("false") => DataCollectionChoice::Disabled,
2585            Some(_) => {
2586                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2587                DataCollectionChoice::NotAnswered
2588            }
2589            None => DataCollectionChoice::NotAnswered,
2590        }
2591    }
2592
2593    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2594        self.data_collection_choice = self.data_collection_choice.toggle();
2595        let new_choice = self.data_collection_choice;
2596        let is_enabled = new_choice.is_enabled(cx);
2597        db::write_and_log(cx, move || {
2598            KEY_VALUE_STORE.write_kvp(
2599                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2600                is_enabled.to_string(),
2601            )
2602        });
2603    }
2604
    /// Iterates over the predictions recorded as shown, in their stored order. The iterator
    /// is double-ended, so callers can also walk it from the other end.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2608
    /// Number of predictions recorded as shown. This is the same collection yielded by
    /// `shown_predictions`.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2612
    /// Returns whether the prediction with `id` has been recorded as rated, e.g. by
    /// `rate_prediction`.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2616
2617    pub fn rate_prediction(
2618        &mut self,
2619        prediction: &EditPrediction,
2620        rating: EditPredictionRating,
2621        feedback: String,
2622        cx: &mut Context<Self>,
2623    ) {
2624        let organization = self.user_store.read(cx).current_organization();
2625
2626        self.rated_predictions.insert(prediction.id.clone());
2627
2628        cx.background_spawn({
2629            let client = self.client.clone();
2630            let prediction_id = prediction.id.to_string();
2631            let inputs = serde_json::to_value(&prediction.inputs);
2632            let output = prediction
2633                .edit_preview
2634                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2635            async move {
2636                client
2637                    .cloud_client()
2638                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2639                        organization_id: organization.map(|organization| organization.id.clone()),
2640                        request_id: prediction_id,
2641                        rating: match rating {
2642                            EditPredictionRating::Positive => "positive".to_string(),
2643                            EditPredictionRating::Negative => "negative".to_string(),
2644                        },
2645                        inputs: inputs?,
2646                        output,
2647                        feedback,
2648                    })
2649                    .await?;
2650
2651                anyhow::Ok(())
2652            }
2653        })
2654        .detach_and_log_err(cx);
2655
2656        cx.notify();
2657    }
2658}
2659
2660fn merge_trailing_events_if_needed(
2661    events: &mut VecDeque<StoredEvent>,
2662    end_snapshot: &TextBufferSnapshot,
2663    latest_snapshot: &TextBufferSnapshot,
2664    latest_edit_range: &Range<Anchor>,
2665) {
2666    if let Some(last_event) = events.back() {
2667        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
2668            return;
2669        }
2670    }
2671
2672    let mut next_old_event = None;
2673    let mut mergeable_count = 0;
2674    for old_event in events.iter().rev() {
2675        if let Some(next_old_event) = &next_old_event
2676            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
2677        {
2678            break;
2679        }
2680        mergeable_count += 1;
2681        next_old_event = Some(old_event);
2682    }
2683
2684    if mergeable_count <= 1 {
2685        return;
2686    }
2687
2688    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
2689    let oldest_event = events_to_merge.peek().unwrap();
2690    let oldest_snapshot = oldest_event.old_snapshot.clone();
2691
2692    if let Some((diff, edited_range)) =
2693        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
2694    {
2695        let merged_event = match oldest_event.event.as_ref() {
2696            zeta_prompt::Event::BufferChange {
2697                old_path,
2698                path,
2699                in_open_source_repo,
2700                ..
2701            } => StoredEvent {
2702                event: Arc::new(zeta_prompt::Event::BufferChange {
2703                    old_path: old_path.clone(),
2704                    path: path.clone(),
2705                    diff,
2706                    in_open_source_repo: *in_open_source_repo,
2707                    predicted: events_to_merge.all(|e| {
2708                        matches!(
2709                            e.event.as_ref(),
2710                            zeta_prompt::Event::BufferChange {
2711                                predicted: true,
2712                                ..
2713                            }
2714                        )
2715                    }),
2716                }),
2717                old_snapshot: oldest_snapshot.clone(),
2718                edit_range: end_snapshot.anchor_before(edited_range.start)
2719                    ..end_snapshot.anchor_before(edited_range.end),
2720            },
2721        };
2722        events.truncate(events.len() - mergeable_count);
2723        events.push_back(merged_event);
2724    }
2725}
2726
2727pub(crate) fn filter_redundant_excerpts(
2728    mut related_files: Vec<RelatedFile>,
2729    cursor_path: &Path,
2730    cursor_row_range: Range<u32>,
2731) -> Vec<RelatedFile> {
2732    for file in &mut related_files {
2733        if file.path.as_ref() == cursor_path {
2734            file.excerpts.retain(|excerpt| {
2735                excerpt.row_range.start < cursor_row_range.start
2736                    || excerpt.row_range.end > cursor_row_range.end
2737            });
2738        }
2739    }
2740    related_files.retain(|file| !file.excerpts.is_empty());
2741    related_files
2742}
2743
/// Error produced when the server advertises a minimum required Zed version that is newer
/// than the running version. `send_api_request` checks this via the
/// `MINIMUM_REQUIRED_VERSION_HEADER_NAME` header. `handle_api_response` shows its message
/// to the user alongside an "Update Zed" link.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version reported by the server.
    minimum_version: Version,
}
2751
/// The user's answer to whether edit prediction data may be collected.
///
/// Derives `PartialEq`/`Eq` so this `Copy` value type can be compared with `==` and used
/// in `assert_eq!`. Adding derives is backward-compatible.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    /// No choice is persisted, or the persisted value was not recognized.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2758
2759impl DataCollectionChoice {
2760    pub fn is_enabled(self, cx: &App) -> bool {
2761        if cx.is_staff() {
2762            return true;
2763        }
2764        match self {
2765            Self::Enabled => true,
2766            Self::NotAnswered | Self::Disabled => false,
2767        }
2768    }
2769
2770    #[must_use]
2771    pub fn toggle(&self) -> DataCollectionChoice {
2772        match self {
2773            Self::Enabled => Self::Disabled,
2774            Self::Disabled => Self::Enabled,
2775            Self::NotAnswered => Self::Enabled,
2776        }
2777    }
2778}
2779
2780impl From<bool> for DataCollectionChoice {
2781    fn from(value: bool) -> Self {
2782        match value {
2783            true => DataCollectionChoice::Enabled,
2784            false => DataCollectionChoice::Disabled,
2785        }
2786    }
2787}
2788
2789struct ZedPredictUpsell;
2790
2791impl Dismissable for ZedPredictUpsell {
2792    const KEY: &'static str = "dismissed-edit-predict-upsell";
2793
2794    fn dismissed() -> bool {
2795        // To make this backwards compatible with older versions of Zed, we
2796        // check if the user has seen the previous Edit Prediction Onboarding
2797        // before, by checking the data collection choice which was written to
2798        // the database once the user clicked on "Accept and Enable"
2799        if KEY_VALUE_STORE
2800            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2801            .log_err()
2802            .is_some_and(|s| s.is_some())
2803        {
2804            return true;
2805        }
2806
2807        KEY_VALUE_STORE
2808            .read_kvp(Self::KEY)
2809            .log_err()
2810            .is_some_and(|s| s.is_some())
2811    }
2812}
2813
2814pub fn should_show_upsell_modal() -> bool {
2815    !ZedPredictUpsell::dismissed()
2816}
2817
2818pub fn init(cx: &mut App) {
2819    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2820        workspace.register_action(
2821            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2822                ZedPredictModal::toggle(
2823                    workspace,
2824                    workspace.user_store().clone(),
2825                    workspace.client().clone(),
2826                    window,
2827                    cx,
2828                )
2829            },
2830        );
2831
2832        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2833            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2834                settings
2835                    .project
2836                    .all_languages
2837                    .edit_predictions
2838                    .get_or_insert_default()
2839                    .provider = Some(EditPredictionProvider::None)
2840            });
2841        });
2842        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2843            EditPredictionStore::try_global(cx).and_then(|store| {
2844                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2845            })
2846        }
2847
2848        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2849            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2850                copilot_ui::initiate_sign_in(copilot, window, cx);
2851            }
2852        });
2853        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2854            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2855                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2856            }
2857        });
2858        workspace.register_action(|workspace, _: &SignOut, window, cx| {
2859            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2860                copilot_ui::initiate_sign_out(copilot, window, cx);
2861            }
2862        });
2863    })
2864    .detach();
2865}