edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_api_types::SubmitEditPredictionFeedbackBody;
   5use cloud_llm_client::predict_edits_v3::{
   6    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   7};
   8use cloud_llm_client::{
   9    EditPredictionRejectReason, EditPredictionRejection,
  10    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  11    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  12};
  13use collections::{HashMap, HashSet};
  14use copilot::{Copilot, Reinstall, SignIn, SignOut};
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::{
  19    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  20    channel::mpsc::{self, UnboundedReceiver},
  21    select_biased,
  22};
  23use gpui::BackgroundExecutor;
  24use gpui::http_client::Url;
  25use gpui::{
  26    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  27    http_client::{self, AsyncBody, Method},
  28    prelude::*,
  29};
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{
  39    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  40};
  41use std::collections::{VecDeque, hash_map};
  42use std::env;
  43use text::Edit;
  44use workspace::Workspace;
  45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  46
  47use std::mem;
  48use std::ops::Range;
  49use std::path::Path;
  50use std::rc::Rc;
  51use std::str::FromStr as _;
  52use std::sync::Arc;
  53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  54use thiserror::Error;
  55use util::{RangeExt as _, ResultExt as _};
  56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  57
  58pub mod cursor_excerpt;
  59pub mod example_spec;
  60pub mod fim;
  61mod license_detection;
  62pub mod mercury;
  63pub mod ollama;
  64mod onboarding_modal;
  65pub mod open_ai_response;
  66mod prediction;
  67pub mod sweep_ai;
  68
  69pub mod udiff;
  70
  71mod capture_example;
  72mod zed_edit_prediction_delegate;
  73pub mod zeta;
  74
  75#[cfg(test)]
  76mod edit_prediction_tests;
  77
  78use crate::license_detection::LicenseDetectionWatcher;
  79use crate::mercury::Mercury;
  80use crate::onboarding_modal::ZedPredictModal;
  81pub use crate::prediction::EditPrediction;
  82pub use crate::prediction::EditPredictionId;
  83use crate::prediction::EditPredictionResult;
  84pub use crate::sweep_ai::SweepAi;
  85pub use capture_example::capture_example;
  86pub use language_model::ApiKeyState;
  87pub use telemetry_events::EditPredictionRating;
  88pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  89
  90actions!(
  91    edit_prediction,
  92    [
  93        /// Resets the edit prediction onboarding state.
  94        ResetOnboarding,
  95        /// Clears the edit prediction history.
  96        ClearHistory,
  97    ]
  98);
  99
 100/// Maximum number of events to track.
 101const EVENT_COUNT_MAX: usize = 6;
 102const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 103const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
 104const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 105const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 106
 107pub struct Zeta2FeatureFlag;
 108pub struct EditPredictionJumpsFeatureFlag;
 109
 110impl FeatureFlag for Zeta2FeatureFlag {
 111    const NAME: &'static str = "zeta2";
 112}
 113
 114impl FeatureFlag for EditPredictionJumpsFeatureFlag {
 115    const NAME: &'static str = "edit_prediction_jumps";
 116}
 117
 118#[derive(Clone)]
 119struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
 120
 121impl Global for EditPredictionStoreGlobal {}
 122
 123/// Configuration for using the raw Zeta2 endpoint.
 124/// When set, the client uses the raw endpoint and constructs the prompt itself.
 125/// The version is also used as the Baseten environment name (lowercased).
 126#[derive(Clone)]
 127pub struct Zeta2RawConfig {
 128    pub model_id: Option<String>,
 129    pub format: ZetaFormat,
 130}
 131
 132pub struct EditPredictionStore {
 133    client: Arc<Client>,
 134    user_store: Entity<UserStore>,
 135    llm_token: LlmApiToken,
 136    _llm_token_subscription: Subscription,
 137    projects: HashMap<EntityId, ProjectState>,
 138    update_required: bool,
 139    edit_prediction_model: EditPredictionModel,
 140    zeta2_raw_config: Option<Zeta2RawConfig>,
 141    pub sweep_ai: SweepAi,
 142    pub mercury: Mercury,
 143    data_collection_choice: DataCollectionChoice,
 144    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
 145    shown_predictions: VecDeque<EditPrediction>,
 146    rated_predictions: HashSet<EditPredictionId>,
 147}
 148
 149#[derive(Copy, Clone, PartialEq, Eq)]
 150pub enum EditPredictionModel {
 151    Zeta1,
 152    Zeta2,
 153    Fim { format: EditPredictionPromptFormat },
 154    Sweep,
 155    Mercury,
 156}
 157
 158#[derive(Clone)]
 159pub struct EditPredictionModelInput {
 160    project: Entity<Project>,
 161    buffer: Entity<Buffer>,
 162    snapshot: BufferSnapshot,
 163    position: Anchor,
 164    events: Vec<Arc<zeta_prompt::Event>>,
 165    related_files: Vec<RelatedFile>,
 166    recent_paths: VecDeque<ProjectPath>,
 167    trigger: PredictEditsRequestTrigger,
 168    diagnostic_search_range: Range<Point>,
 169    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 170    pub user_actions: Vec<UserActionRecord>,
 171}
 172
 173#[derive(Debug)]
 174pub enum DebugEvent {
 175    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
 176    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
 177    EditPredictionStarted(EditPredictionStartedDebugEvent),
 178    EditPredictionFinished(EditPredictionFinishedDebugEvent),
 179}
 180
 181#[derive(Debug)]
 182pub struct ContextRetrievalStartedDebugEvent {
 183    pub project_entity_id: EntityId,
 184    pub timestamp: Instant,
 185    pub search_prompt: String,
 186}
 187
 188#[derive(Debug)]
 189pub struct ContextRetrievalFinishedDebugEvent {
 190    pub project_entity_id: EntityId,
 191    pub timestamp: Instant,
 192    pub metadata: Vec<(&'static str, SharedString)>,
 193}
 194
 195#[derive(Debug)]
 196pub struct EditPredictionStartedDebugEvent {
 197    pub buffer: WeakEntity<Buffer>,
 198    pub position: Anchor,
 199    pub prompt: Option<String>,
 200}
 201
 202#[derive(Debug)]
 203pub struct EditPredictionFinishedDebugEvent {
 204    pub buffer: WeakEntity<Buffer>,
 205    pub position: Anchor,
 206    pub model_output: Option<String>,
 207}
 208
 209const USER_ACTION_HISTORY_SIZE: usize = 16;
 210
 211#[derive(Clone, Debug)]
 212pub struct UserActionRecord {
 213    pub action_type: UserActionType,
 214    pub buffer_id: EntityId,
 215    pub line_number: u32,
 216    pub offset: usize,
 217    pub timestamp_epoch_ms: u64,
 218}
 219
 220#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 221pub enum UserActionType {
 222    InsertChar,
 223    InsertSelection,
 224    DeleteChar,
 225    DeleteSelection,
 226    CursorMovement,
 227}
 228
 229/// An event with associated metadata for reconstructing buffer state.
 230#[derive(Clone)]
 231pub struct StoredEvent {
 232    pub event: Arc<zeta_prompt::Event>,
 233    pub old_snapshot: TextBufferSnapshot,
 234    pub edit_range: Range<Anchor>,
 235}
 236
 237impl StoredEvent {
 238    fn can_merge(
 239        &self,
 240        next_old_event: &&&StoredEvent,
 241        new_snapshot: &TextBufferSnapshot,
 242        last_edit_range: &Range<Anchor>,
 243    ) -> bool {
 244        // Events must be for the same buffer
 245        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 246            return false;
 247        }
 248        if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
 249            return false;
 250        }
 251
 252        let a_is_predicted = matches!(
 253            self.event.as_ref(),
 254            zeta_prompt::Event::BufferChange {
 255                predicted: true,
 256                ..
 257            }
 258        );
 259        let b_is_predicted = matches!(
 260            next_old_event.event.as_ref(),
 261            zeta_prompt::Event::BufferChange {
 262                predicted: true,
 263                ..
 264            }
 265        );
 266
 267        // If events come from the same source (both predicted or both manual) then
 268        // we would have coalesced them already.
 269        if a_is_predicted == b_is_predicted {
 270            return false;
 271        }
 272
 273        let left_range = self.edit_range.to_point(new_snapshot);
 274        let right_range = next_old_event.edit_range.to_point(new_snapshot);
 275        let latest_range = last_edit_range.to_point(&new_snapshot);
 276
 277        // Events near to the latest edit are not merged if their sources differ.
 278        if lines_between_ranges(&left_range, &latest_range)
 279            .min(lines_between_ranges(&right_range, &latest_range))
 280            <= CHANGE_GROUPING_LINE_SPAN
 281        {
 282            return false;
 283        }
 284
 285        // Events that are distant from each other are not merged.
 286        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 287            return false;
 288        }
 289
 290        true
 291    }
 292}
 293
 294fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 295    if left.start > right.end {
 296        return left.start.row - right.end.row;
 297    }
 298    if right.start > left.end {
 299        return right.start.row - left.end.row;
 300    }
 301    0
 302}
 303
 304struct ProjectState {
 305    events: VecDeque<StoredEvent>,
 306    last_event: Option<LastEvent>,
 307    recent_paths: VecDeque<ProjectPath>,
 308    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 309    current_prediction: Option<CurrentEditPrediction>,
 310    next_pending_prediction_id: usize,
 311    pending_predictions: ArrayVec<PendingPrediction, 2>,
 312    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 313    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
 314    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
 315    cancelled_predictions: HashSet<usize>,
 316    context: Entity<RelatedExcerptStore>,
 317    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 318    user_actions: VecDeque<UserActionRecord>,
 319    _subscriptions: [gpui::Subscription; 2],
 320    copilot: Option<Entity<Copilot>>,
 321}
 322
 323impl ProjectState {
 324    fn record_user_action(&mut self, action: UserActionRecord) {
 325        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 326            self.user_actions.pop_front();
 327        }
 328        self.user_actions.push_back(action);
 329    }
 330
 331    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 332        self.events
 333            .iter()
 334            .cloned()
 335            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 336                let (one, two) = event.split_by_pause();
 337                let one = one.finalize(&self.license_detection_watchers, cx);
 338                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 339                one.into_iter().chain(two)
 340            }))
 341            .collect()
 342    }
 343
 344    fn cancel_pending_prediction(
 345        &mut self,
 346        pending_prediction: PendingPrediction,
 347        cx: &mut Context<EditPredictionStore>,
 348    ) {
 349        self.cancelled_predictions.insert(pending_prediction.id);
 350
 351        if pending_prediction.drop_on_cancel {
 352            drop(pending_prediction.task);
 353        } else {
 354            cx.spawn(async move |this, cx| {
 355                let Some(prediction_id) = pending_prediction.task.await else {
 356                    return;
 357                };
 358
 359                this.update(cx, |this, cx| {
 360                    this.reject_prediction(
 361                        prediction_id,
 362                        EditPredictionRejectReason::Canceled,
 363                        false,
 364                        None,
 365                        cx,
 366                    );
 367                })
 368                .ok();
 369            })
 370            .detach()
 371        }
 372    }
 373
 374    fn active_buffer(
 375        &self,
 376        project: &Entity<Project>,
 377        cx: &App,
 378    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 379        let project = project.read(cx);
 380        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 381        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 382        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 383        Some((active_buffer, registered_buffer.last_position))
 384    }
 385}
 386
 387#[derive(Debug, Clone)]
 388struct CurrentEditPrediction {
 389    pub requested_by: PredictionRequestedBy,
 390    pub prediction: EditPrediction,
 391    pub was_shown: bool,
 392    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
 393}
 394
 395impl CurrentEditPrediction {
 396    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 397        let Some(new_edits) = self
 398            .prediction
 399            .interpolate(&self.prediction.buffer.read(cx))
 400        else {
 401            return false;
 402        };
 403
 404        if self.prediction.buffer != old_prediction.prediction.buffer {
 405            return true;
 406        }
 407
 408        let Some(old_edits) = old_prediction
 409            .prediction
 410            .interpolate(&old_prediction.prediction.buffer.read(cx))
 411        else {
 412            return true;
 413        };
 414
 415        let requested_by_buffer_id = self.requested_by.buffer_id();
 416
 417        // This reduces the occurrence of UI thrash from replacing edits
 418        //
 419        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 420        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 421            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 422            && old_edits.len() == 1
 423            && new_edits.len() == 1
 424        {
 425            let (old_range, old_text) = &old_edits[0];
 426            let (new_range, new_text) = &new_edits[0];
 427            new_range == old_range && new_text.starts_with(old_text.as_ref())
 428        } else {
 429            true
 430        }
 431    }
 432}
 433
 434#[derive(Debug, Clone)]
 435enum PredictionRequestedBy {
 436    DiagnosticsUpdate,
 437    Buffer(EntityId),
 438}
 439
 440impl PredictionRequestedBy {
 441    pub fn buffer_id(&self) -> Option<EntityId> {
 442        match self {
 443            PredictionRequestedBy::DiagnosticsUpdate => None,
 444            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 445        }
 446    }
 447}
 448
 449const DIAGNOSTIC_LINES_RANGE: u32 = 20;
 450
 451#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 452pub enum DiagnosticSearchScope {
 453    Local,
 454    Global,
 455}
 456
 457#[derive(Debug)]
 458struct PendingPrediction {
 459    id: usize,
 460    task: Task<Option<EditPredictionId>>,
 461    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
 462    /// If false, the task is awaited to completion so rejection can be reported.
 463    drop_on_cancel: bool,
 464}
 465
 466/// A prediction from the perspective of a buffer.
 467#[derive(Debug)]
 468enum BufferEditPrediction<'a> {
 469    Local { prediction: &'a EditPrediction },
 470    Jump { prediction: &'a EditPrediction },
 471}
 472
 473#[cfg(test)]
 474impl std::ops::Deref for BufferEditPrediction<'_> {
 475    type Target = EditPrediction;
 476
 477    fn deref(&self) -> &Self::Target {
 478        match self {
 479            BufferEditPrediction::Local { prediction } => prediction,
 480            BufferEditPrediction::Jump { prediction } => prediction,
 481        }
 482    }
 483}
 484
 485struct RegisteredBuffer {
 486    file: Option<Arc<dyn File>>,
 487    snapshot: TextBufferSnapshot,
 488    last_position: Option<Anchor>,
 489    _subscriptions: [gpui::Subscription; 2],
 490}
 491
 492#[derive(Clone)]
 493struct LastEvent {
 494    old_snapshot: TextBufferSnapshot,
 495    new_snapshot: TextBufferSnapshot,
 496    old_file: Option<Arc<dyn File>>,
 497    new_file: Option<Arc<dyn File>>,
 498    edit_range: Option<Range<Anchor>>,
 499    predicted: bool,
 500    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
 501    last_edit_time: Option<Instant>,
 502}
 503
 504impl LastEvent {
 505    pub fn finalize(
 506        &self,
 507        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 508        cx: &App,
 509    ) -> Option<StoredEvent> {
 510        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 511        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 512
 513        let in_open_source_repo =
 514            [self.new_file.as_ref(), self.old_file.as_ref()]
 515                .iter()
 516                .all(|file| {
 517                    file.is_some_and(|file| {
 518                        license_detection_watchers
 519                            .get(&file.worktree_id(cx))
 520                            .is_some_and(|watcher| watcher.is_project_open_source())
 521                    })
 522                });
 523
 524        let (diff, edit_range) =
 525            compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 526
 527        if path == old_path && diff.is_empty() {
 528            None
 529        } else {
 530            Some(StoredEvent {
 531                event: Arc::new(zeta_prompt::Event::BufferChange {
 532                    old_path,
 533                    path,
 534                    diff,
 535                    in_open_source_repo,
 536                    predicted: self.predicted,
 537                }),
 538                edit_range: self.new_snapshot.anchor_before(edit_range.start)
 539                    ..self.new_snapshot.anchor_before(edit_range.end),
 540                old_snapshot: self.old_snapshot.clone(),
 541            })
 542        }
 543    }
 544
 545    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 546        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 547            return (self.clone(), None);
 548        };
 549
 550        let before = LastEvent {
 551            old_snapshot: self.old_snapshot.clone(),
 552            new_snapshot: boundary_snapshot.clone(),
 553            old_file: self.old_file.clone(),
 554            new_file: self.new_file.clone(),
 555            edit_range: None,
 556            predicted: self.predicted,
 557            snapshot_after_last_editing_pause: None,
 558            last_edit_time: self.last_edit_time,
 559        };
 560
 561        let after = LastEvent {
 562            old_snapshot: boundary_snapshot.clone(),
 563            new_snapshot: self.new_snapshot.clone(),
 564            old_file: self.old_file.clone(),
 565            new_file: self.new_file.clone(),
 566            edit_range: None,
 567            predicted: self.predicted,
 568            snapshot_after_last_editing_pause: None,
 569            last_edit_time: self.last_edit_time,
 570        };
 571
 572        (before, Some(after))
 573    }
 574}
 575
 576pub(crate) fn compute_diff_between_snapshots(
 577    old_snapshot: &TextBufferSnapshot,
 578    new_snapshot: &TextBufferSnapshot,
 579) -> Option<(String, Range<Point>)> {
 580    let edits: Vec<Edit<usize>> = new_snapshot
 581        .edits_since::<usize>(&old_snapshot.version)
 582        .collect();
 583
 584    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 585
 586    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 587    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 588    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 589    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 590
 591    const CONTEXT_LINES: u32 = 3;
 592
 593    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 594    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 595    let old_context_end_row =
 596        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 597    let new_context_end_row =
 598        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 599
 600    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 601    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 602    let old_end_line_offset = old_snapshot
 603        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 604    let new_end_line_offset = new_snapshot
 605        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 606    let old_edit_range = old_start_line_offset..old_end_line_offset;
 607    let new_edit_range = new_start_line_offset..new_end_line_offset;
 608
 609    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 610    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 611
 612    let diff = language::unified_diff_with_offsets(
 613        &old_region_text,
 614        &new_region_text,
 615        old_context_start_row,
 616        new_context_start_row,
 617    );
 618
 619    Some((diff, new_start_point..new_end_point))
 620}
 621
 622fn buffer_path_with_id_fallback(
 623    file: Option<&Arc<dyn File>>,
 624    snapshot: &TextBufferSnapshot,
 625    cx: &App,
 626) -> Arc<Path> {
 627    if let Some(file) = file {
 628        file.full_path(cx).into()
 629    } else {
 630        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 631    }
 632}
 633
 634impl EditPredictionStore {
 635    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 636        cx.try_global::<EditPredictionStoreGlobal>()
 637            .map(|global| global.0.clone())
 638    }
 639
 640    pub fn global(
 641        client: &Arc<Client>,
 642        user_store: &Entity<UserStore>,
 643        cx: &mut App,
 644    ) -> Entity<Self> {
 645        cx.try_global::<EditPredictionStoreGlobal>()
 646            .map(|global| global.0.clone())
 647            .unwrap_or_else(|| {
 648                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 649                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 650                ep_store
 651            })
 652    }
 653
 654    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 655        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 656        let data_collection_choice = Self::load_data_collection_choice();
 657
 658        let llm_token = LlmApiToken::default();
 659
 660        let (reject_tx, reject_rx) = mpsc::unbounded();
 661        cx.background_spawn({
 662            let client = client.clone();
 663            let llm_token = llm_token.clone();
 664            let app_version = AppVersion::global(cx);
 665            let background_executor = cx.background_executor().clone();
 666            async move {
 667                Self::handle_rejected_predictions(
 668                    reject_rx,
 669                    client,
 670                    llm_token,
 671                    app_version,
 672                    background_executor,
 673                )
 674                .await
 675            }
 676        })
 677        .detach();
 678
 679        let this = Self {
 680            projects: HashMap::default(),
 681            client,
 682            user_store,
 683            llm_token,
 684            _llm_token_subscription: cx.subscribe(
 685                &refresh_llm_token_listener,
 686                |this, _listener, _event, cx| {
 687                    let client = this.client.clone();
 688                    let llm_token = this.llm_token.clone();
 689                    cx.spawn(async move |_this, _cx| {
 690                        llm_token.refresh(&client).await?;
 691                        anyhow::Ok(())
 692                    })
 693                    .detach_and_log_err(cx);
 694                },
 695            ),
 696            update_required: false,
 697            edit_prediction_model: EditPredictionModel::Zeta2,
 698            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 699            sweep_ai: SweepAi::new(cx),
 700            mercury: Mercury::new(cx),
 701
 702            data_collection_choice,
 703            reject_predictions_tx: reject_tx,
 704            rated_predictions: Default::default(),
 705            shown_predictions: Default::default(),
 706        };
 707
 708        this
 709    }
 710
 711    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 712        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 713        let format = ZetaFormat::parse(&version_str).ok()?;
 714        let model_id = env::var("ZED_ZETA_MODEL").ok();
 715        Some(Zeta2RawConfig { model_id, format })
 716    }
 717
 718    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 719        self.edit_prediction_model = model;
 720    }
 721
 722    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
 723        self.zeta2_raw_config = Some(config);
 724    }
 725
 726    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
 727        self.zeta2_raw_config.as_ref()
 728    }
 729
 730    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 731        use ui::IconName;
 732        match self.edit_prediction_model {
 733            EditPredictionModel::Sweep => {
 734                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 735                    .with_disabled(IconName::SweepAiDisabled)
 736                    .with_up(IconName::SweepAiUp)
 737                    .with_down(IconName::SweepAiDown)
 738                    .with_error(IconName::SweepAiError)
 739            }
 740            EditPredictionModel::Mercury => {
 741                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 742            }
 743            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
 744                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 745                    .with_disabled(IconName::ZedPredictDisabled)
 746                    .with_up(IconName::ZedPredictUp)
 747                    .with_down(IconName::ZedPredictDown)
 748                    .with_error(IconName::ZedPredictError)
 749            }
 750            EditPredictionModel::Fim { .. } => {
 751                let settings = &all_language_settings(None, cx).edit_predictions;
 752                match settings.provider {
 753                    EditPredictionProvider::Ollama => {
 754                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 755                    }
 756                    _ => {
 757                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 758                    }
 759                }
 760            }
 761        }
 762    }
 763
 764    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 765        self.sweep_ai.api_token.read(cx).has_key()
 766    }
 767
 768    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 769        self.mercury.api_token.read(cx).has_key()
 770    }
 771
 772    pub fn clear_history(&mut self) {
 773        for project_state in self.projects.values_mut() {
 774            project_state.events.clear();
 775            project_state.last_event.take();
 776        }
 777    }
 778
 779    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 780        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 781            project_state.events.clear();
 782            project_state.last_event.take();
 783        }
 784    }
 785
 786    pub fn edit_history_for_project(
 787        &self,
 788        project: &Entity<Project>,
 789        cx: &App,
 790    ) -> Vec<StoredEvent> {
 791        self.projects
 792            .get(&project.entity_id())
 793            .map(|project_state| project_state.events(cx))
 794            .unwrap_or_default()
 795    }
 796
 797    pub fn context_for_project<'a>(
 798        &'a self,
 799        project: &Entity<Project>,
 800        cx: &'a mut App,
 801    ) -> Vec<RelatedFile> {
 802        self.projects
 803            .get(&project.entity_id())
 804            .map(|project_state| {
 805                project_state.context.update(cx, |context, cx| {
 806                    context
 807                        .related_files_with_buffers(cx)
 808                        .map(|(mut related_file, buffer)| {
 809                            related_file.in_open_source_repo = buffer
 810                                .read(cx)
 811                                .file()
 812                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
 813                            related_file
 814                        })
 815                        .collect()
 816                })
 817            })
 818            .unwrap_or_default()
 819    }
 820
 821    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 822        self.projects
 823            .get(&project.entity_id())
 824            .and_then(|project| project.copilot.clone())
 825    }
 826
 827    pub fn start_copilot_for_project(
 828        &mut self,
 829        project: &Entity<Project>,
 830        cx: &mut Context<Self>,
 831    ) -> Option<Entity<Copilot>> {
 832        if DisableAiSettings::get(None, cx).disable_ai {
 833            return None;
 834        }
 835        let state = self.get_or_init_project(project, cx);
 836
 837        if state.copilot.is_some() {
 838            return state.copilot.clone();
 839        }
 840        let _project = project.clone();
 841        let project = project.read(cx);
 842
 843        let node = project.node_runtime().cloned();
 844        if let Some(node) = node {
 845            let next_id = project.languages().next_language_server_id();
 846            let fs = project.fs().clone();
 847
 848            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 849            state.copilot = Some(copilot.clone());
 850            Some(copilot)
 851        } else {
 852            None
 853        }
 854    }
 855
 856    pub fn context_for_project_with_buffers<'a>(
 857        &'a self,
 858        project: &Entity<Project>,
 859        cx: &'a mut App,
 860    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 861        self.projects
 862            .get(&project.entity_id())
 863            .map(|project| {
 864                project.context.update(cx, |context, cx| {
 865                    context.related_files_with_buffers(cx).collect()
 866                })
 867            })
 868            .unwrap_or_default()
 869    }
 870
 871    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 872        if matches!(
 873            self.edit_prediction_model,
 874            EditPredictionModel::Zeta2 | EditPredictionModel::Zeta1
 875        ) {
 876            self.user_store.read(cx).edit_prediction_usage()
 877        } else {
 878            None
 879        }
 880    }
 881
 882    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 883        self.get_or_init_project(project, cx);
 884    }
 885
 886    pub fn register_buffer(
 887        &mut self,
 888        buffer: &Entity<Buffer>,
 889        project: &Entity<Project>,
 890        cx: &mut Context<Self>,
 891    ) {
 892        let project_state = self.get_or_init_project(project, cx);
 893        Self::register_buffer_impl(project_state, buffer, project, cx);
 894    }
 895
 896    fn get_or_init_project(
 897        &mut self,
 898        project: &Entity<Project>,
 899        cx: &mut Context<Self>,
 900    ) -> &mut ProjectState {
 901        let entity_id = project.entity_id();
 902        self.projects
 903            .entry(entity_id)
 904            .or_insert_with(|| ProjectState {
 905                context: {
 906                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 907                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 908                        this.handle_excerpt_store_event(entity_id, event);
 909                    })
 910                    .detach();
 911                    related_excerpt_store
 912                },
 913                events: VecDeque::new(),
 914                last_event: None,
 915                recent_paths: VecDeque::new(),
 916                debug_tx: None,
 917                registered_buffers: HashMap::default(),
 918                current_prediction: None,
 919                cancelled_predictions: HashSet::default(),
 920                pending_predictions: ArrayVec::new(),
 921                next_pending_prediction_id: 0,
 922                last_edit_prediction_refresh: None,
 923                last_jump_prediction_refresh: None,
 924                license_detection_watchers: HashMap::default(),
 925                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
 926                _subscriptions: [
 927                    cx.subscribe(&project, Self::handle_project_event),
 928                    cx.observe_release(&project, move |this, _, cx| {
 929                        this.projects.remove(&entity_id);
 930                        cx.notify();
 931                    }),
 932                ],
 933                copilot: None,
 934            })
 935    }
 936
 937    pub fn remove_project(&mut self, project: &Entity<Project>) {
 938        self.projects.remove(&project.entity_id());
 939    }
 940
 941    fn handle_excerpt_store_event(
 942        &mut self,
 943        project_entity_id: EntityId,
 944        event: &RelatedExcerptStoreEvent,
 945    ) {
 946        if let Some(project_state) = self.projects.get(&project_entity_id) {
 947            if let Some(debug_tx) = project_state.debug_tx.clone() {
 948                match event {
 949                    RelatedExcerptStoreEvent::StartedRefresh => {
 950                        debug_tx
 951                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 952                                ContextRetrievalStartedDebugEvent {
 953                                    project_entity_id: project_entity_id,
 954                                    timestamp: Instant::now(),
 955                                    search_prompt: String::new(),
 956                                },
 957                            ))
 958                            .ok();
 959                    }
 960                    RelatedExcerptStoreEvent::FinishedRefresh {
 961                        cache_hit_count,
 962                        cache_miss_count,
 963                        mean_definition_latency,
 964                        max_definition_latency,
 965                    } => {
 966                        debug_tx
 967                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 968                                ContextRetrievalFinishedDebugEvent {
 969                                    project_entity_id: project_entity_id,
 970                                    timestamp: Instant::now(),
 971                                    metadata: vec![
 972                                        (
 973                                            "Cache Hits",
 974                                            format!(
 975                                                "{}/{}",
 976                                                cache_hit_count,
 977                                                cache_hit_count + cache_miss_count
 978                                            )
 979                                            .into(),
 980                                        ),
 981                                        (
 982                                            "Max LSP Time",
 983                                            format!("{} ms", max_definition_latency.as_millis())
 984                                                .into(),
 985                                        ),
 986                                        (
 987                                            "Mean LSP Time",
 988                                            format!("{} ms", mean_definition_latency.as_millis())
 989                                                .into(),
 990                                        ),
 991                                    ],
 992                                },
 993                            ))
 994                            .ok();
 995                    }
 996                }
 997            }
 998        }
 999    }
1000
1001    pub fn debug_info(
1002        &mut self,
1003        project: &Entity<Project>,
1004        cx: &mut Context<Self>,
1005    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1006        let project_state = self.get_or_init_project(project, cx);
1007        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1008        project_state.debug_tx = Some(debug_watch_tx);
1009        debug_watch_rx
1010    }
1011
1012    fn handle_project_event(
1013        &mut self,
1014        project: Entity<Project>,
1015        event: &project::Event,
1016        cx: &mut Context<Self>,
1017    ) {
1018        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1019            return;
1020        }
1021        // TODO [zeta2] init with recent paths
1022        match event {
1023            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1024                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1025                    return;
1026                };
1027                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1028                if let Some(path) = path {
1029                    if let Some(ix) = project_state
1030                        .recent_paths
1031                        .iter()
1032                        .position(|probe| probe == &path)
1033                    {
1034                        project_state.recent_paths.remove(ix);
1035                    }
1036                    project_state.recent_paths.push_front(path);
1037                }
1038            }
1039            project::Event::DiagnosticsUpdated { .. } => {
1040                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1041                    self.refresh_prediction_from_diagnostics(
1042                        project,
1043                        DiagnosticSearchScope::Global,
1044                        cx,
1045                    );
1046                }
1047            }
1048            _ => (),
1049        }
1050    }
1051
1052    fn register_buffer_impl<'a>(
1053        project_state: &'a mut ProjectState,
1054        buffer: &Entity<Buffer>,
1055        project: &Entity<Project>,
1056        cx: &mut Context<Self>,
1057    ) -> &'a mut RegisteredBuffer {
1058        let buffer_id = buffer.entity_id();
1059
1060        if let Some(file) = buffer.read(cx).file() {
1061            let worktree_id = file.worktree_id(cx);
1062            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
1063                project_state
1064                    .license_detection_watchers
1065                    .entry(worktree_id)
1066                    .or_insert_with(|| {
1067                        let project_entity_id = project.entity_id();
1068                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
1069                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
1070                            else {
1071                                return;
1072                            };
1073                            project_state
1074                                .license_detection_watchers
1075                                .remove(&worktree_id);
1076                        })
1077                        .detach();
1078                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
1079                    });
1080            }
1081        }
1082
1083        match project_state.registered_buffers.entry(buffer_id) {
1084            hash_map::Entry::Occupied(entry) => entry.into_mut(),
1085            hash_map::Entry::Vacant(entry) => {
1086                let buf = buffer.read(cx);
1087                let snapshot = buf.text_snapshot();
1088                let file = buf.file().cloned();
1089                let project_entity_id = project.entity_id();
1090                entry.insert(RegisteredBuffer {
1091                    snapshot,
1092                    file,
1093                    last_position: None,
1094                    _subscriptions: [
1095                        cx.subscribe(buffer, {
1096                            let project = project.downgrade();
1097                            move |this, buffer, event, cx| {
1098                                if let language::BufferEvent::Edited = event
1099                                    && let Some(project) = project.upgrade()
1100                                {
1101                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
1102                                }
1103                            }
1104                        }),
1105                        cx.observe_release(buffer, move |this, _buffer, _cx| {
1106                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
1107                            else {
1108                                return;
1109                            };
1110                            project_state.registered_buffers.remove(&buffer_id);
1111                        }),
1112                    ],
1113                })
1114            }
1115        }
1116    }
1117
1118    fn report_changes_for_buffer(
1119        &mut self,
1120        buffer: &Entity<Buffer>,
1121        project: &Entity<Project>,
1122        is_predicted: bool,
1123        cx: &mut Context<Self>,
1124    ) {
1125        let project_state = self.get_or_init_project(project, cx);
1126        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
1127
1128        let buf = buffer.read(cx);
1129        let new_file = buf.file().cloned();
1130        let new_snapshot = buf.text_snapshot();
1131        if new_snapshot.version == registered_buffer.snapshot.version {
1132            return;
1133        }
1134
1135        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
1136        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1137        let mut num_edits = 0usize;
1138        let mut total_deleted = 0usize;
1139        let mut total_inserted = 0usize;
1140        let mut edit_range: Option<Range<Anchor>> = None;
1141        let mut last_offset: Option<usize> = None;
1142
1143        for (edit, anchor_range) in
1144            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
1145        {
1146            num_edits += 1;
1147            total_deleted += edit.old.len();
1148            total_inserted += edit.new.len();
1149            edit_range = Some(match edit_range {
1150                None => anchor_range,
1151                Some(acc) => acc.start..anchor_range.end,
1152            });
1153            last_offset = Some(edit.new.end);
1154        }
1155
1156        let Some(edit_range) = edit_range else {
1157            return;
1158        };
1159
1160        let action_type = match (total_deleted, total_inserted, num_edits) {
1161            (0, ins, n) if ins == n => UserActionType::InsertChar,
1162            (0, _, _) => UserActionType::InsertSelection,
1163            (del, 0, n) if del == n => UserActionType::DeleteChar,
1164            (_, 0, _) => UserActionType::DeleteSelection,
1165            (_, ins, n) if ins == n => UserActionType::InsertChar,
1166            (_, _, _) => UserActionType::InsertSelection,
1167        };
1168
1169        if let Some(offset) = last_offset {
1170            let point = new_snapshot.offset_to_point(offset);
1171            let timestamp_epoch_ms = SystemTime::now()
1172                .duration_since(UNIX_EPOCH)
1173                .map(|d| d.as_millis() as u64)
1174                .unwrap_or(0);
1175            project_state.record_user_action(UserActionRecord {
1176                action_type,
1177                buffer_id: buffer.entity_id(),
1178                line_number: point.row,
1179                offset,
1180                timestamp_epoch_ms,
1181            });
1182        }
1183
1184        let events = &mut project_state.events;
1185
1186        let now = cx.background_executor().now();
1187        if let Some(last_event) = project_state.last_event.as_mut() {
1188            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
1189                == last_event.new_snapshot.remote_id()
1190                && old_snapshot.version == last_event.new_snapshot.version;
1191
1192            let prediction_source_changed = is_predicted != last_event.predicted;
1193
1194            let should_coalesce = is_next_snapshot_of_same_buffer
1195                && !prediction_source_changed
1196                && last_event
1197                    .edit_range
1198                    .as_ref()
1199                    .is_some_and(|last_edit_range| {
1200                        lines_between_ranges(
1201                            &edit_range.to_point(&new_snapshot),
1202                            &last_edit_range.to_point(&new_snapshot),
1203                        ) <= CHANGE_GROUPING_LINE_SPAN
1204                    });
1205
1206            if should_coalesce {
1207                let pause_elapsed = last_event
1208                    .last_edit_time
1209                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
1210                    .unwrap_or(false);
1211                if pause_elapsed {
1212                    last_event.snapshot_after_last_editing_pause =
1213                        Some(last_event.new_snapshot.clone());
1214                }
1215
1216                last_event.edit_range = Some(edit_range);
1217                last_event.new_snapshot = new_snapshot;
1218                last_event.last_edit_time = Some(now);
1219                return;
1220            }
1221        }
1222
1223        if let Some(event) = project_state.last_event.take() {
1224            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
1225                if events.len() + 1 >= EVENT_COUNT_MAX {
1226                    events.pop_front();
1227                }
1228                events.push_back(event);
1229            }
1230        }
1231
1232        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);
1233
1234        project_state.last_event = Some(LastEvent {
1235            old_file,
1236            new_file,
1237            old_snapshot,
1238            new_snapshot,
1239            edit_range: Some(edit_range),
1240            predicted: is_predicted,
1241            snapshot_after_last_editing_pause: None,
1242            last_edit_time: Some(now),
1243        });
1244    }
1245
1246    fn prediction_at(
1247        &mut self,
1248        buffer: &Entity<Buffer>,
1249        position: Option<language::Anchor>,
1250        project: &Entity<Project>,
1251        cx: &App,
1252    ) -> Option<BufferEditPrediction<'_>> {
1253        let project_state = self.projects.get_mut(&project.entity_id())?;
1254        if let Some(position) = position
1255            && let Some(buffer) = project_state
1256                .registered_buffers
1257                .get_mut(&buffer.entity_id())
1258        {
1259            buffer.last_position = Some(position);
1260        }
1261
1262        let CurrentEditPrediction {
1263            requested_by,
1264            prediction,
1265            ..
1266        } = project_state.current_prediction.as_ref()?;
1267
1268        if prediction.targets_buffer(buffer.read(cx)) {
1269            Some(BufferEditPrediction::Local { prediction })
1270        } else {
1271            let show_jump = match requested_by {
1272                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1273                    requested_by_buffer_id == &buffer.entity_id()
1274                }
1275                PredictionRequestedBy::DiagnosticsUpdate => true,
1276            };
1277
1278            if show_jump {
1279                Some(BufferEditPrediction::Jump { prediction })
1280            } else {
1281                None
1282            }
1283        }
1284    }
1285
1286    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1287        let Some(current_prediction) = self
1288            .projects
1289            .get_mut(&project.entity_id())
1290            .and_then(|project_state| project_state.current_prediction.take())
1291        else {
1292            return;
1293        };
1294
1295        self.report_changes_for_buffer(&current_prediction.prediction.buffer, project, true, cx);
1296
1297        // can't hold &mut project_state ref across report_changes_for_buffer_call
1298        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1299            return;
1300        };
1301
1302        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1303            project_state.cancel_pending_prediction(pending_prediction, cx);
1304        }
1305
1306        match self.edit_prediction_model {
1307            EditPredictionModel::Sweep => {
1308                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1309            }
1310            EditPredictionModel::Mercury => {
1311                mercury::edit_prediction_accepted(
1312                    current_prediction.prediction.id,
1313                    self.client.http_client(),
1314                    cx,
1315                );
1316            }
1317            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1318                let is_cloud = !matches!(
1319                    all_language_settings(None, cx).edit_predictions.provider,
1320                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1321                );
1322                if is_cloud {
1323                    zeta::edit_prediction_accepted(self, current_prediction, cx)
1324                }
1325            }
1326            EditPredictionModel::Fim { .. } => {}
1327        }
1328    }
1329
1330    async fn handle_rejected_predictions(
1331        rx: UnboundedReceiver<EditPredictionRejection>,
1332        client: Arc<Client>,
1333        llm_token: LlmApiToken,
1334        app_version: Version,
1335        background_executor: BackgroundExecutor,
1336    ) {
1337        let mut rx = std::pin::pin!(rx.peekable());
1338        let mut batched = Vec::new();
1339
1340        while let Some(rejection) = rx.next().await {
1341            batched.push(rejection);
1342
1343            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1344                select_biased! {
1345                    next = rx.as_mut().peek().fuse() => {
1346                        if next.is_some() {
1347                            continue;
1348                        }
1349                    }
1350                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1351                }
1352            }
1353
1354            let url = client
1355                .http_client()
1356                .build_zed_llm_url("/predict_edits/reject", &[])
1357                .unwrap();
1358
1359            let flush_count = batched
1360                .len()
1361                // in case items have accumulated after failure
1362                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1363            let start = batched.len() - flush_count;
1364
1365            let body = RejectEditPredictionsBodyRef {
1366                rejections: &batched[start..],
1367            };
1368
1369            let result = Self::send_api_request::<()>(
1370                |builder| {
1371                    let req = builder
1372                        .uri(url.as_ref())
1373                        .body(serde_json::to_string(&body)?.into());
1374                    anyhow::Ok(req?)
1375                },
1376                client.clone(),
1377                llm_token.clone(),
1378                app_version.clone(),
1379                true,
1380            )
1381            .await;
1382
1383            if result.log_err().is_some() {
1384                batched.drain(start..);
1385            }
1386        }
1387    }
1388
1389    fn reject_current_prediction(
1390        &mut self,
1391        reason: EditPredictionRejectReason,
1392        project: &Entity<Project>,
1393        cx: &App,
1394    ) {
1395        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1396            project_state.pending_predictions.clear();
1397            if let Some(prediction) = project_state.current_prediction.take() {
1398                let model_version = prediction.prediction.model_version.clone();
1399                self.reject_prediction(
1400                    prediction.prediction.id,
1401                    reason,
1402                    prediction.was_shown,
1403                    model_version,
1404                    cx,
1405                );
1406            }
1407        };
1408    }
1409
1410    fn did_show_current_prediction(
1411        &mut self,
1412        project: &Entity<Project>,
1413        display_type: edit_prediction_types::SuggestionDisplayType,
1414        cx: &mut Context<Self>,
1415    ) {
1416        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1417            return;
1418        };
1419
1420        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1421            return;
1422        };
1423
1424        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1425        let previous_shown_with = current_prediction.shown_with;
1426
1427        if previous_shown_with.is_none() || !is_jump {
1428            current_prediction.shown_with = Some(display_type);
1429        }
1430
1431        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1432
1433        if is_first_non_jump_show {
1434            current_prediction.was_shown = true;
1435        }
1436
1437        let display_type_changed = previous_shown_with != Some(display_type);
1438
1439        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1440            sweep_ai::edit_prediction_shown(
1441                &self.sweep_ai,
1442                self.client.clone(),
1443                &current_prediction.prediction,
1444                display_type,
1445                cx,
1446            );
1447        }
1448
1449        if is_first_non_jump_show {
1450            self.shown_predictions
1451                .push_front(current_prediction.prediction.clone());
1452            if self.shown_predictions.len() > 50 {
1453                let completion = self.shown_predictions.pop_back().unwrap();
1454                self.rated_predictions.remove(&completion.id);
1455            }
1456        }
1457    }
1458
1459    fn reject_prediction(
1460        &mut self,
1461        prediction_id: EditPredictionId,
1462        reason: EditPredictionRejectReason,
1463        was_shown: bool,
1464        model_version: Option<String>,
1465        cx: &App,
1466    ) {
1467        match self.edit_prediction_model {
1468            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1469                let is_cloud = !matches!(
1470                    all_language_settings(None, cx).edit_predictions.provider,
1471                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1472                );
1473                if is_cloud {
1474                    self.reject_predictions_tx
1475                        .unbounded_send(EditPredictionRejection {
1476                            request_id: prediction_id.to_string(),
1477                            reason,
1478                            was_shown,
1479                            model_version,
1480                        })
1481                        .log_err();
1482                }
1483            }
1484            EditPredictionModel::Mercury => {
1485                mercury::edit_prediction_rejected(
1486                    prediction_id,
1487                    was_shown,
1488                    reason,
1489                    self.client.http_client(),
1490                    cx,
1491                );
1492            }
1493            EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1494        }
1495    }
1496
1497    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1498        self.projects
1499            .get(&project.entity_id())
1500            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1501    }
1502
1503    pub fn refresh_prediction_from_buffer(
1504        &mut self,
1505        project: Entity<Project>,
1506        buffer: Entity<Buffer>,
1507        position: language::Anchor,
1508        cx: &mut Context<Self>,
1509    ) {
1510        self.queue_prediction_refresh(
1511            project.clone(),
1512            PredictEditsRequestTrigger::Other,
1513            buffer.entity_id(),
1514            cx,
1515            move |this, cx| {
1516                let Some(request_task) = this
1517                    .update(cx, |this, cx| {
1518                        this.request_prediction(
1519                            &project,
1520                            &buffer,
1521                            position,
1522                            PredictEditsRequestTrigger::Other,
1523                            cx,
1524                        )
1525                    })
1526                    .log_err()
1527                else {
1528                    return Task::ready(anyhow::Ok(None));
1529                };
1530
1531                cx.spawn(async move |_cx| {
1532                    request_task.await.map(|prediction_result| {
1533                        prediction_result.map(|prediction_result| {
1534                            (
1535                                prediction_result,
1536                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1537                            )
1538                        })
1539                    })
1540                })
1541            },
1542        )
1543    }
1544
1545    pub fn refresh_prediction_from_diagnostics(
1546        &mut self,
1547        project: Entity<Project>,
1548        scope: DiagnosticSearchScope,
1549        cx: &mut Context<Self>,
1550    ) {
1551        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1552            return;
1553        }
1554
1555        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1556            return;
1557        };
1558
1559        // Prefer predictions from buffer
1560        if project_state.current_prediction.is_some() {
1561            return;
1562        }
1563
1564        self.queue_prediction_refresh(
1565            project.clone(),
1566            PredictEditsRequestTrigger::Diagnostics,
1567            project.entity_id(),
1568            cx,
1569            move |this, cx| {
1570                let Some((active_buffer, snapshot, cursor_point)) = this
1571                    .read_with(cx, |this, cx| {
1572                        let project_state = this.projects.get(&project.entity_id())?;
1573                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
1574                        let snapshot = buffer.read(cx).snapshot();
1575
1576                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
1577                            return None;
1578                        }
1579
1580                        let cursor_point = position
1581                            .map(|pos| pos.to_point(&snapshot))
1582                            .unwrap_or_default();
1583
1584                        Some((buffer, snapshot, cursor_point))
1585                    })
1586                    .log_err()
1587                    .flatten()
1588                else {
1589                    return Task::ready(anyhow::Ok(None));
1590                };
1591
1592                cx.spawn(async move |cx| {
1593                    let diagnostic_search_range = match scope {
1594                        DiagnosticSearchScope::Local => {
1595                            let diagnostic_search_start =
1596                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1597                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1598                            Point::new(diagnostic_search_start, 0)
1599                                ..Point::new(diagnostic_search_end, 0)
1600                        }
1601                        DiagnosticSearchScope::Global => Default::default(),
1602                    };
1603
1604                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1605                        active_buffer,
1606                        &snapshot,
1607                        diagnostic_search_range,
1608                        cursor_point,
1609                        &project,
1610                        cx,
1611                    )
1612                    .await?
1613                    else {
1614                        return anyhow::Ok(None);
1615                    };
1616
1617                    let Some(prediction_result) = this
1618                        .update(cx, |this, cx| {
1619                            this.request_prediction(
1620                                &project,
1621                                &jump_buffer,
1622                                jump_position,
1623                                PredictEditsRequestTrigger::Diagnostics,
1624                                cx,
1625                            )
1626                        })?
1627                        .await?
1628                    else {
1629                        return anyhow::Ok(None);
1630                    };
1631
1632                    this.update(cx, |this, cx| {
1633                        Some((
1634                            if this
1635                                .get_or_init_project(&project, cx)
1636                                .current_prediction
1637                                .is_none()
1638                            {
1639                                prediction_result
1640                            } else {
1641                                EditPredictionResult {
1642                                    id: prediction_result.id,
1643                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1644                                }
1645                            },
1646                            PredictionRequestedBy::DiagnosticsUpdate,
1647                        ))
1648                    })
1649                })
1650            },
1651        );
1652    }
1653
1654    fn predictions_enabled_at(
1655        snapshot: &BufferSnapshot,
1656        position: Option<language::Anchor>,
1657        cx: &App,
1658    ) -> bool {
1659        let file = snapshot.file();
1660        let all_settings = all_language_settings(file, cx);
1661        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1662            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1663        {
1664            return false;
1665        }
1666
1667        if let Some(last_position) = position {
1668            let settings = snapshot.settings_at(last_position, cx);
1669
1670            if !settings.edit_predictions_disabled_in.is_empty()
1671                && let Some(scope) = snapshot.language_scope_at(last_position)
1672                && let Some(scope_name) = scope.override_name()
1673                && settings
1674                    .edit_predictions_disabled_in
1675                    .iter()
1676                    .any(|s| s == scope_name)
1677            {
1678                return false;
1679            }
1680        }
1681
1682        true
1683    }
1684
1685    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1686}
1687
1688fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1689    match provider {
1690        EditPredictionProvider::Zed
1691        | EditPredictionProvider::Sweep
1692        | EditPredictionProvider::Mercury
1693        | EditPredictionProvider::Ollama
1694        | EditPredictionProvider::OpenAiCompatibleApi
1695        | EditPredictionProvider::Experimental(_) => true,
1696        EditPredictionProvider::None
1697        | EditPredictionProvider::Copilot
1698        | EditPredictionProvider::Supermaven
1699        | EditPredictionProvider::Codestral => false,
1700    }
1701}
1702
1703impl EditPredictionStore {
1704    fn queue_prediction_refresh(
1705        &mut self,
1706        project: Entity<Project>,
1707        request_trigger: PredictEditsRequestTrigger,
1708        throttle_entity: EntityId,
1709        cx: &mut Context<Self>,
1710        do_refresh: impl FnOnce(
1711            WeakEntity<Self>,
1712            &mut AsyncApp,
1713        )
1714            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1715        + 'static,
1716    ) {
1717        fn select_throttle(
1718            project_state: &mut ProjectState,
1719            request_trigger: PredictEditsRequestTrigger,
1720        ) -> &mut Option<(EntityId, Instant)> {
1721            match request_trigger {
1722                PredictEditsRequestTrigger::Diagnostics => {
1723                    &mut project_state.last_jump_prediction_refresh
1724                }
1725                _ => &mut project_state.last_edit_prediction_refresh,
1726            }
1727        }
1728
1729        let (needs_acceptance_tracking, max_pending_predictions) =
1730            match all_language_settings(None, cx).edit_predictions.provider {
1731                EditPredictionProvider::Zed
1732                | EditPredictionProvider::Sweep
1733                | EditPredictionProvider::Mercury
1734                | EditPredictionProvider::Experimental(_) => (true, 2),
1735                EditPredictionProvider::Ollama => (false, 1),
1736                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
1737                EditPredictionProvider::None
1738                | EditPredictionProvider::Copilot
1739                | EditPredictionProvider::Supermaven
1740                | EditPredictionProvider::Codestral => {
1741                    log::error!("queue_prediction_refresh called with non-store provider");
1742                    return;
1743                }
1744            };
1745
1746        let drop_on_cancel = !needs_acceptance_tracking;
1747        let throttle_timeout = Self::THROTTLE_TIMEOUT;
1748        let project_state = self.get_or_init_project(&project, cx);
1749        let pending_prediction_id = project_state.next_pending_prediction_id;
1750        project_state.next_pending_prediction_id += 1;
1751        let last_request = *select_throttle(project_state, request_trigger);
1752
1753        let task = cx.spawn(async move |this, cx| {
1754            if let Some(timeout) = last_request.and_then(|(last_entity, last_timestamp)| {
1755                if throttle_entity != last_entity {
1756                    return None;
1757                }
1758                (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
1759            }) {
1760                cx.background_executor().timer(timeout).await;
1761            }
1762
1763            // If this task was cancelled before the throttle timeout expired,
1764            // do not perform a request.
1765            let mut is_cancelled = true;
1766            this.update(cx, |this, cx| {
1767                let project_state = this.get_or_init_project(&project, cx);
1768                let was_cancelled = project_state
1769                    .cancelled_predictions
1770                    .remove(&pending_prediction_id);
1771                if !was_cancelled {
1772                    let new_refresh = (throttle_entity, Instant::now());
1773                    *select_throttle(project_state, request_trigger) = Some(new_refresh);
1774                    is_cancelled = false;
1775                }
1776            })
1777            .ok();
1778            if is_cancelled {
1779                return None;
1780            }
1781
1782            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1783            let new_prediction_id = new_prediction_result
1784                .as_ref()
1785                .map(|(prediction, _)| prediction.id.clone());
1786
1787            // When a prediction completes, remove it from the pending list, and cancel
1788            // any pending predictions that were enqueued before it.
1789            this.update(cx, |this, cx| {
1790                let project_state = this.get_or_init_project(&project, cx);
1791
1792                let is_cancelled = project_state
1793                    .cancelled_predictions
1794                    .remove(&pending_prediction_id);
1795
1796                let new_current_prediction = if !is_cancelled
1797                    && let Some((prediction_result, requested_by)) = new_prediction_result
1798                {
1799                    match prediction_result.prediction {
1800                        Ok(prediction) => {
1801                            let new_prediction = CurrentEditPrediction {
1802                                requested_by,
1803                                prediction,
1804                                was_shown: false,
1805                                shown_with: None,
1806                            };
1807
1808                            if let Some(current_prediction) =
1809                                project_state.current_prediction.as_ref()
1810                            {
1811                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1812                                {
1813                                    this.reject_current_prediction(
1814                                        EditPredictionRejectReason::Replaced,
1815                                        &project,
1816                                        cx,
1817                                    );
1818
1819                                    Some(new_prediction)
1820                                } else {
1821                                    this.reject_prediction(
1822                                        new_prediction.prediction.id,
1823                                        EditPredictionRejectReason::CurrentPreferred,
1824                                        false,
1825                                        new_prediction.prediction.model_version,
1826                                        cx,
1827                                    );
1828                                    None
1829                                }
1830                            } else {
1831                                Some(new_prediction)
1832                            }
1833                        }
1834                        Err(reject_reason) => {
1835                            this.reject_prediction(
1836                                prediction_result.id,
1837                                reject_reason,
1838                                false,
1839                                None,
1840                                cx,
1841                            );
1842                            None
1843                        }
1844                    }
1845                } else {
1846                    None
1847                };
1848
1849                let project_state = this.get_or_init_project(&project, cx);
1850
1851                if let Some(new_prediction) = new_current_prediction {
1852                    project_state.current_prediction = Some(new_prediction);
1853                }
1854
1855                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1856                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1857                    if pending_prediction.id == pending_prediction_id {
1858                        pending_predictions.remove(ix);
1859                        for pending_prediction in pending_predictions.drain(0..ix) {
1860                            project_state.cancel_pending_prediction(pending_prediction, cx)
1861                        }
1862                        break;
1863                    }
1864                }
1865                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1866                cx.notify();
1867            })
1868            .ok();
1869
1870            new_prediction_id
1871        });
1872
1873        if project_state.pending_predictions.len() < max_pending_predictions {
1874            project_state.pending_predictions.push(PendingPrediction {
1875                id: pending_prediction_id,
1876                task,
1877                drop_on_cancel,
1878            });
1879        } else {
1880            let pending_prediction = project_state.pending_predictions.pop().unwrap();
1881            project_state.pending_predictions.push(PendingPrediction {
1882                id: pending_prediction_id,
1883                task,
1884                drop_on_cancel,
1885            });
1886            project_state.cancel_pending_prediction(pending_prediction, cx);
1887        }
1888    }
1889
1890    pub fn request_prediction(
1891        &mut self,
1892        project: &Entity<Project>,
1893        active_buffer: &Entity<Buffer>,
1894        position: language::Anchor,
1895        trigger: PredictEditsRequestTrigger,
1896        cx: &mut Context<Self>,
1897    ) -> Task<Result<Option<EditPredictionResult>>> {
1898        self.request_prediction_internal(
1899            project.clone(),
1900            active_buffer.clone(),
1901            position,
1902            trigger,
1903            cx.has_flag::<Zeta2FeatureFlag>(),
1904            cx,
1905        )
1906    }
1907
1908    fn request_prediction_internal(
1909        &mut self,
1910        project: Entity<Project>,
1911        active_buffer: Entity<Buffer>,
1912        position: language::Anchor,
1913        trigger: PredictEditsRequestTrigger,
1914        allow_jump: bool,
1915        cx: &mut Context<Self>,
1916    ) -> Task<Result<Option<EditPredictionResult>>> {
1917        self.get_or_init_project(&project, cx);
1918        let project_state = self.projects.get(&project.entity_id()).unwrap();
1919        let stored_events = project_state.events(cx);
1920        let has_events = !stored_events.is_empty();
1921        let events: Vec<Arc<zeta_prompt::Event>> =
1922            stored_events.into_iter().map(|e| e.event).collect();
1923        let debug_tx = project_state.debug_tx.clone();
1924
1925        let snapshot = active_buffer.read(cx).snapshot();
1926        let cursor_point = position.to_point(&snapshot);
1927        let current_offset = position.to_offset(&snapshot);
1928
1929        let mut user_actions: Vec<UserActionRecord> =
1930            project_state.user_actions.iter().cloned().collect();
1931
1932        if let Some(last_action) = user_actions.last() {
1933            if last_action.buffer_id == active_buffer.entity_id()
1934                && current_offset != last_action.offset
1935            {
1936                let timestamp_epoch_ms = SystemTime::now()
1937                    .duration_since(UNIX_EPOCH)
1938                    .map(|d| d.as_millis() as u64)
1939                    .unwrap_or(0);
1940                user_actions.push(UserActionRecord {
1941                    action_type: UserActionType::CursorMovement,
1942                    buffer_id: active_buffer.entity_id(),
1943                    line_number: cursor_point.row,
1944                    offset: current_offset,
1945                    timestamp_epoch_ms,
1946                });
1947            }
1948        }
1949        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1950        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1951        let diagnostic_search_range =
1952            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1953
1954        let related_files = self.context_for_project(&project, cx);
1955
1956        let inputs = EditPredictionModelInput {
1957            project: project.clone(),
1958            buffer: active_buffer,
1959            snapshot: snapshot,
1960            position,
1961            events,
1962            related_files,
1963            recent_paths: project_state.recent_paths.clone(),
1964            trigger,
1965            diagnostic_search_range: diagnostic_search_range,
1966            debug_tx,
1967            user_actions,
1968        };
1969
1970        let task = match self.edit_prediction_model {
1971            EditPredictionModel::Zeta1 => zeta::request_prediction_with_zeta(
1972                self,
1973                inputs,
1974                Some(zeta_prompt::EditPredictionModelKind::Zeta1),
1975                cx,
1976            ),
1977            EditPredictionModel::Zeta2 => zeta::request_prediction_with_zeta(
1978                self,
1979                inputs,
1980                Some(zeta_prompt::EditPredictionModelKind::Zeta2),
1981                cx,
1982            ),
1983            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
1984            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1985            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1986        };
1987
1988        cx.spawn(async move |this, cx| {
1989            let prediction = task.await?;
1990
1991            if prediction.is_none() && allow_jump && has_events {
1992                this.update(cx, |this, cx| {
1993                    this.refresh_prediction_from_diagnostics(
1994                        project,
1995                        DiagnosticSearchScope::Local,
1996                        cx,
1997                    );
1998                })?;
1999                return anyhow::Ok(None);
2000            }
2001
2002            Ok(prediction)
2003        })
2004    }
2005
2006    pub(crate) async fn next_diagnostic_location(
2007        active_buffer: Entity<Buffer>,
2008        active_buffer_snapshot: &BufferSnapshot,
2009        active_buffer_diagnostic_search_range: Range<Point>,
2010        active_buffer_cursor_point: Point,
2011        project: &Entity<Project>,
2012        cx: &mut AsyncApp,
2013    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2014        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2015            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2016            .flat_map(|(_, _, _, selections)| {
2017                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2018            })
2019            .collect();
2020
2021        let mut jump_location = active_buffer_snapshot
2022            .diagnostic_groups(None)
2023            .into_iter()
2024            .filter_map(|(_, group)| {
2025                let range = &group.entries[group.primary_ix]
2026                    .range
2027                    .to_point(&active_buffer_snapshot);
2028                if range.overlaps(&active_buffer_diagnostic_search_range) {
2029                    return None;
2030                }
2031                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2032                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2033                });
2034                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2035                    <= DIAGNOSTIC_LINES_RANGE;
2036                if near_collaborator && !near_local {
2037                    return None;
2038                }
2039                Some(range.start)
2040            })
2041            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2042            .map(|position| {
2043                (
2044                    active_buffer.clone(),
2045                    active_buffer_snapshot.anchor_before(position),
2046                )
2047            });
2048
2049        if jump_location.is_none() {
2050            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2051                let file = buffer.file()?;
2052
2053                Some(ProjectPath {
2054                    worktree_id: file.worktree_id(cx),
2055                    path: file.path().clone(),
2056                })
2057            });
2058
2059            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2060                project
2061                    .diagnostic_summaries(false, cx)
2062                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2063                    .map(|(path, _, _)| {
2064                        let shared_prefix = path
2065                            .path
2066                            .components()
2067                            .zip(
2068                                active_buffer_path
2069                                    .as_ref()
2070                                    .map(|p| p.path.components())
2071                                    .unwrap_or_default(),
2072                            )
2073                            .take_while(|(a, b)| a == b)
2074                            .count();
2075                        (path, shared_prefix)
2076                    })
2077                    .collect()
2078            });
2079
2080            candidates.sort_by(|a, b| b.1.cmp(&a.1));
2081
2082            for (path, _) in candidates {
2083                let candidate_buffer = project
2084                    .update(cx, |project, cx| project.open_buffer(path, cx))
2085                    .await?;
2086
2087                let (has_collaborators, diagnostic_position) =
2088                    candidate_buffer.read_with(cx, |buffer, _cx| {
2089                        let snapshot = buffer.snapshot();
2090                        let has_collaborators = snapshot
2091                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2092                            .next()
2093                            .is_some();
2094                        let position = buffer
2095                            .buffer_diagnostics(None)
2096                            .into_iter()
2097                            .min_by_key(|entry| entry.diagnostic.severity)
2098                            .map(|entry| entry.range.start);
2099                        (has_collaborators, position)
2100                    });
2101
2102                if has_collaborators {
2103                    continue;
2104                }
2105
2106                if let Some(position) = diagnostic_position {
2107                    jump_location = Some((candidate_buffer, position));
2108                    break;
2109                }
2110            }
2111        }
2112
2113        anyhow::Ok(jump_location)
2114    }
2115
2116    async fn send_raw_llm_request(
2117        request: RawCompletionRequest,
2118        client: Arc<Client>,
2119        custom_url: Option<Arc<Url>>,
2120        llm_token: LlmApiToken,
2121        app_version: Version,
2122    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2123        let url = if let Some(custom_url) = custom_url {
2124            custom_url.as_ref().clone()
2125        } else {
2126            client
2127                .http_client()
2128                .build_zed_llm_url("/predict_edits/raw", &[])?
2129        };
2130
2131        Self::send_api_request(
2132            |builder| {
2133                let req = builder
2134                    .uri(url.as_ref())
2135                    .body(serde_json::to_string(&request)?.into());
2136                Ok(req?)
2137            },
2138            client,
2139            llm_token,
2140            app_version,
2141            true,
2142        )
2143        .await
2144    }
2145
2146    pub(crate) async fn send_v3_request(
2147        input: ZetaPromptInput,
2148        client: Arc<Client>,
2149        llm_token: LlmApiToken,
2150        app_version: Version,
2151        trigger: PredictEditsRequestTrigger,
2152    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2153        let url = client
2154            .http_client()
2155            .build_zed_llm_url("/predict_edits/v3", &[])?;
2156
2157        let request = PredictEditsV3Request { input, trigger };
2158
2159        let json_bytes = serde_json::to_vec(&request)?;
2160        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2161
2162        Self::send_api_request(
2163            |builder| {
2164                let req = builder
2165                    .uri(url.as_ref())
2166                    .header("Content-Encoding", "zstd")
2167                    .body(compressed.clone().into());
2168                Ok(req?)
2169            },
2170            client,
2171            llm_token,
2172            app_version,
2173            true,
2174        )
2175        .await
2176    }
2177
2178    fn handle_api_response<T>(
2179        this: &WeakEntity<Self>,
2180        response: Result<(T, Option<EditPredictionUsage>)>,
2181        cx: &mut gpui::AsyncApp,
2182    ) -> Result<T> {
2183        match response {
2184            Ok((data, usage)) => {
2185                if let Some(usage) = usage {
2186                    this.update(cx, |this, cx| {
2187                        this.user_store.update(cx, |user_store, cx| {
2188                            user_store.update_edit_prediction_usage(usage, cx);
2189                        });
2190                    })
2191                    .ok();
2192                }
2193                Ok(data)
2194            }
2195            Err(err) => {
2196                if err.is::<ZedUpdateRequiredError>() {
2197                    cx.update(|cx| {
2198                        this.update(cx, |this, _cx| {
2199                            this.update_required = true;
2200                        })
2201                        .ok();
2202
2203                        let error_message: SharedString = err.to_string().into();
2204                        show_app_notification(
2205                            NotificationId::unique::<ZedUpdateRequiredError>(),
2206                            cx,
2207                            move |cx| {
2208                                cx.new(|cx| {
2209                                    ErrorMessagePrompt::new(error_message.clone(), cx)
2210                                        .with_link_button("Update Zed", "https://zed.dev/releases")
2211                                })
2212                            },
2213                        );
2214                    });
2215                }
2216                Err(err)
2217            }
2218        }
2219    }
2220
2221    async fn send_api_request<Res>(
2222        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2223        client: Arc<Client>,
2224        llm_token: LlmApiToken,
2225        app_version: Version,
2226        require_auth: bool,
2227    ) -> Result<(Res, Option<EditPredictionUsage>)>
2228    where
2229        Res: DeserializeOwned,
2230    {
2231        let http_client = client.http_client();
2232
2233        let mut token = if require_auth {
2234            Some(llm_token.acquire(&client).await?)
2235        } else {
2236            llm_token.acquire(&client).await.ok()
2237        };
2238        let mut did_retry = false;
2239
2240        loop {
2241            let request_builder = http_client::Request::builder().method(Method::POST);
2242
2243            let mut request_builder = request_builder
2244                .header("Content-Type", "application/json")
2245                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2246
2247            // Only add Authorization header if we have a token
2248            if let Some(ref token_value) = token {
2249                request_builder =
2250                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2251            }
2252
2253            let request = build(request_builder)?;
2254
2255            let mut response = http_client.send(request).await?;
2256
2257            if let Some(minimum_required_version) = response
2258                .headers()
2259                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2260                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2261            {
2262                anyhow::ensure!(
2263                    app_version >= minimum_required_version,
2264                    ZedUpdateRequiredError {
2265                        minimum_version: minimum_required_version
2266                    }
2267                );
2268            }
2269
2270            if response.status().is_success() {
2271                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2272
2273                let mut body = Vec::new();
2274                response.body_mut().read_to_end(&mut body).await?;
2275                return Ok((serde_json::from_slice(&body)?, usage));
2276            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2277                did_retry = true;
2278                token = Some(llm_token.refresh(&client).await?);
2279            } else {
2280                let mut body = String::new();
2281                response.body_mut().read_to_string(&mut body).await?;
2282                anyhow::bail!(
2283                    "Request failed with status: {:?}\nBody: {}",
2284                    response.status(),
2285                    body
2286                );
2287            }
2288        }
2289    }
2290
2291    pub fn refresh_context(
2292        &mut self,
2293        project: &Entity<Project>,
2294        buffer: &Entity<language::Buffer>,
2295        cursor_position: language::Anchor,
2296        cx: &mut Context<Self>,
2297    ) {
2298        self.get_or_init_project(project, cx)
2299            .context
2300            .update(cx, |store, cx| {
2301                store.refresh(buffer.clone(), cursor_position, cx);
2302            });
2303    }
2304
2305    #[cfg(feature = "cli-support")]
2306    pub fn set_context_for_buffer(
2307        &mut self,
2308        project: &Entity<Project>,
2309        related_files: Vec<RelatedFile>,
2310        cx: &mut Context<Self>,
2311    ) {
2312        self.get_or_init_project(project, cx)
2313            .context
2314            .update(cx, |store, cx| {
2315                store.set_related_files(related_files, cx);
2316            });
2317    }
2318
2319    #[cfg(feature = "cli-support")]
2320    pub fn set_recent_paths_for_project(
2321        &mut self,
2322        project: &Entity<Project>,
2323        paths: impl IntoIterator<Item = project::ProjectPath>,
2324        cx: &mut Context<Self>,
2325    ) {
2326        let project_state = self.get_or_init_project(project, cx);
2327        project_state.recent_paths = paths.into_iter().collect();
2328    }
2329
2330    fn is_file_open_source(
2331        &self,
2332        project: &Entity<Project>,
2333        file: &Arc<dyn File>,
2334        cx: &App,
2335    ) -> bool {
2336        if !file.is_local() || file.is_private() {
2337            return false;
2338        }
2339        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2340            return false;
2341        };
2342        project_state
2343            .license_detection_watchers
2344            .get(&file.worktree_id(cx))
2345            .as_ref()
2346            .is_some_and(|watcher| watcher.is_project_open_source())
2347    }
2348
2349    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2350        self.data_collection_choice.is_enabled(cx)
2351    }
2352
2353    fn load_data_collection_choice() -> DataCollectionChoice {
2354        let choice = KEY_VALUE_STORE
2355            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2356            .log_err()
2357            .flatten();
2358
2359        match choice.as_deref() {
2360            Some("true") => DataCollectionChoice::Enabled,
2361            Some("false") => DataCollectionChoice::Disabled,
2362            Some(_) => {
2363                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2364                DataCollectionChoice::NotAnswered
2365            }
2366            None => DataCollectionChoice::NotAnswered,
2367        }
2368    }
2369
2370    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2371        self.data_collection_choice = self.data_collection_choice.toggle();
2372        let new_choice = self.data_collection_choice;
2373        let is_enabled = new_choice.is_enabled(cx);
2374        db::write_and_log(cx, move || {
2375            KEY_VALUE_STORE.write_kvp(
2376                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2377                is_enabled.to_string(),
2378            )
2379        });
2380    }
2381
2382    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2383        self.shown_predictions.iter()
2384    }
2385
2386    pub fn shown_completions_len(&self) -> usize {
2387        self.shown_predictions.len()
2388    }
2389
2390    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2391        self.rated_predictions.contains(id)
2392    }
2393
2394    pub fn rate_prediction(
2395        &mut self,
2396        prediction: &EditPrediction,
2397        rating: EditPredictionRating,
2398        feedback: String,
2399        cx: &mut Context<Self>,
2400    ) {
2401        let organization = self.user_store.read(cx).current_organization();
2402
2403        self.rated_predictions.insert(prediction.id.clone());
2404
2405        cx.background_spawn({
2406            let client = self.client.clone();
2407            let prediction_id = prediction.id.to_string();
2408            let inputs = serde_json::to_value(&prediction.inputs);
2409            let output = prediction
2410                .edit_preview
2411                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2412            async move {
2413                client
2414                    .cloud_client()
2415                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2416                        organization_id: organization.map(|organization| organization.id.clone()),
2417                        request_id: prediction_id,
2418                        rating: match rating {
2419                            EditPredictionRating::Positive => "positive".to_string(),
2420                            EditPredictionRating::Negative => "negative".to_string(),
2421                        },
2422                        inputs: inputs?,
2423                        output,
2424                        feedback,
2425                    })
2426                    .await?;
2427
2428                anyhow::Ok(())
2429            }
2430        })
2431        .detach_and_log_err(cx);
2432
2433        cx.notify();
2434    }
2435}
2436
2437fn merge_trailing_events_if_needed(
2438    events: &mut VecDeque<StoredEvent>,
2439    end_snapshot: &TextBufferSnapshot,
2440    latest_snapshot: &TextBufferSnapshot,
2441    latest_edit_range: &Range<Anchor>,
2442) {
2443    if let Some(last_event) = events.back() {
2444        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
2445            return;
2446        }
2447    }
2448
2449    let mut next_old_event = None;
2450    let mut mergeable_count = 0;
2451    for old_event in events.iter().rev() {
2452        if let Some(next_old_event) = &next_old_event
2453            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
2454        {
2455            break;
2456        }
2457        mergeable_count += 1;
2458        next_old_event = Some(old_event);
2459    }
2460
2461    if mergeable_count <= 1 {
2462        return;
2463    }
2464
2465    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
2466    let oldest_event = events_to_merge.peek().unwrap();
2467    let oldest_snapshot = oldest_event.old_snapshot.clone();
2468
2469    if let Some((diff, edited_range)) =
2470        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
2471    {
2472        let merged_event = match oldest_event.event.as_ref() {
2473            zeta_prompt::Event::BufferChange {
2474                old_path,
2475                path,
2476                in_open_source_repo,
2477                ..
2478            } => StoredEvent {
2479                event: Arc::new(zeta_prompt::Event::BufferChange {
2480                    old_path: old_path.clone(),
2481                    path: path.clone(),
2482                    diff,
2483                    in_open_source_repo: *in_open_source_repo,
2484                    predicted: events_to_merge.all(|e| {
2485                        matches!(
2486                            e.event.as_ref(),
2487                            zeta_prompt::Event::BufferChange {
2488                                predicted: true,
2489                                ..
2490                            }
2491                        )
2492                    }),
2493                }),
2494                old_snapshot: oldest_snapshot.clone(),
2495                edit_range: end_snapshot.anchor_before(edited_range.start)
2496                    ..end_snapshot.anchor_before(edited_range.end),
2497            },
2498        };
2499        events.truncate(events.len() - mergeable_count);
2500        events.push_back(merged_event);
2501    }
2502}
2503
2504pub(crate) fn filter_redundant_excerpts(
2505    mut related_files: Vec<RelatedFile>,
2506    cursor_path: &Path,
2507    cursor_row_range: Range<u32>,
2508) -> Vec<RelatedFile> {
2509    for file in &mut related_files {
2510        if file.path.as_ref() == cursor_path {
2511            file.excerpts.retain(|excerpt| {
2512                excerpt.row_range.start < cursor_row_range.start
2513                    || excerpt.row_range.end > cursor_row_range.end
2514            });
2515        }
2516    }
2517    related_files.retain(|file| !file.excerpts.is_empty());
2518    related_files
2519}
2520
2521#[derive(Error, Debug)]
2522#[error(
2523    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2524)]
2525pub struct ZedUpdateRequiredError {
2526    minimum_version: Version,
2527}
2528
2529#[derive(Debug, Clone, Copy)]
2530pub enum DataCollectionChoice {
2531    NotAnswered,
2532    Enabled,
2533    Disabled,
2534}
2535
2536impl DataCollectionChoice {
2537    pub fn is_enabled(self, cx: &App) -> bool {
2538        if cx.is_staff() {
2539            return true;
2540        }
2541        match self {
2542            Self::Enabled => true,
2543            Self::NotAnswered | Self::Disabled => false,
2544        }
2545    }
2546
2547    #[must_use]
2548    pub fn toggle(&self) -> DataCollectionChoice {
2549        match self {
2550            Self::Enabled => Self::Disabled,
2551            Self::Disabled => Self::Enabled,
2552            Self::NotAnswered => Self::Enabled,
2553        }
2554    }
2555}
2556
2557impl From<bool> for DataCollectionChoice {
2558    fn from(value: bool) -> Self {
2559        match value {
2560            true => DataCollectionChoice::Enabled,
2561            false => DataCollectionChoice::Disabled,
2562        }
2563    }
2564}
2565
2566struct ZedPredictUpsell;
2567
2568impl Dismissable for ZedPredictUpsell {
2569    const KEY: &'static str = "dismissed-edit-predict-upsell";
2570
2571    fn dismissed() -> bool {
2572        // To make this backwards compatible with older versions of Zed, we
2573        // check if the user has seen the previous Edit Prediction Onboarding
2574        // before, by checking the data collection choice which was written to
2575        // the database once the user clicked on "Accept and Enable"
2576        if KEY_VALUE_STORE
2577            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2578            .log_err()
2579            .is_some_and(|s| s.is_some())
2580        {
2581            return true;
2582        }
2583
2584        KEY_VALUE_STORE
2585            .read_kvp(Self::KEY)
2586            .log_err()
2587            .is_some_and(|s| s.is_some())
2588    }
2589}
2590
2591pub fn should_show_upsell_modal() -> bool {
2592    !ZedPredictUpsell::dismissed()
2593}
2594
2595pub fn init(cx: &mut App) {
2596    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2597        workspace.register_action(
2598            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2599                ZedPredictModal::toggle(
2600                    workspace,
2601                    workspace.user_store().clone(),
2602                    workspace.client().clone(),
2603                    window,
2604                    cx,
2605                )
2606            },
2607        );
2608
2609        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2610            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2611                settings
2612                    .project
2613                    .all_languages
2614                    .edit_predictions
2615                    .get_or_insert_default()
2616                    .provider = Some(EditPredictionProvider::None)
2617            });
2618        });
2619        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2620            EditPredictionStore::try_global(cx).and_then(|store| {
2621                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2622            })
2623        }
2624
2625        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2626            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2627                copilot_ui::initiate_sign_in(copilot, window, cx);
2628            }
2629        });
2630        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2631            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2632                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2633            }
2634        });
2635        workspace.register_action(|workspace, _: &SignOut, window, cx| {
2636            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2637                copilot_ui::initiate_sign_out(copilot, window, cx);
2638            }
2639        });
2640    })
2641    .detach();
2642}