edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_api_types::SubmitEditPredictionFeedbackBody;
   5use cloud_llm_client::predict_edits_v3::{
   6    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   7};
   8use cloud_llm_client::{
   9    EditPredictionRejectReason, EditPredictionRejection,
  10    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  11    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  12};
  13use collections::{HashMap, HashSet};
  14use copilot::{Copilot, Reinstall, SignIn, SignOut};
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::{
  19    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  20    channel::mpsc::{self, UnboundedReceiver},
  21    select_biased,
  22};
  23use gpui::BackgroundExecutor;
  24use gpui::http_client::Url;
  25use gpui::{
  26    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  27    http_client::{self, AsyncBody, Method},
  28    prelude::*,
  29};
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{
  39    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  40};
  41use std::collections::{VecDeque, hash_map};
  42use std::env;
  43use text::Edit;
  44use workspace::Workspace;
  45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  46
  47use std::mem;
  48use std::ops::Range;
  49use std::path::Path;
  50use std::rc::Rc;
  51use std::str::FromStr as _;
  52use std::sync::Arc;
  53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  54use thiserror::Error;
  55use util::{RangeExt as _, ResultExt as _};
  56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  57
  58pub mod cursor_excerpt;
  59pub mod example_spec;
  60pub mod fim;
  61mod license_detection;
  62pub mod mercury;
  63pub mod ollama;
  64mod onboarding_modal;
  65pub mod open_ai_response;
  66mod prediction;
  67pub mod sweep_ai;
  68
  69pub mod udiff;
  70
  71mod capture_example;
  72mod zed_edit_prediction_delegate;
  73pub mod zeta;
  74
  75#[cfg(test)]
  76mod edit_prediction_tests;
  77
  78use crate::license_detection::LicenseDetectionWatcher;
  79use crate::mercury::Mercury;
  80use crate::onboarding_modal::ZedPredictModal;
  81pub use crate::prediction::EditPrediction;
  82pub use crate::prediction::EditPredictionId;
  83use crate::prediction::EditPredictionResult;
  84pub use crate::sweep_ai::SweepAi;
  85pub use capture_example::capture_example;
  86pub use language_model::ApiKeyState;
  87pub use telemetry_events::EditPredictionRating;
  88pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  89
// Actions declared through gpui's `actions!` macro under the `edit_prediction` name.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  99
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Line-distance threshold used by `StoredEvent::can_merge` when deciding whether two
/// edits are close enough to each other (and far enough from the latest edit) to merge.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// One-second window related to grouping the most recent change.
/// NOTE(review): the consumer is outside this view; presumably the pause length that
/// splits `LastEvent` — confirm.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// String key for the data collection choice.
/// NOTE(review): presumably a key in the key-value store — confirm against the loader.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Fifteen-second debounce interval for rejection requests.
/// NOTE(review): the consumer is outside this view — confirm usage.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 106
/// Marker type for the feature flag named `"zeta2"`.
pub struct Zeta2FeatureFlag;
/// Marker type for the feature flag named `"edit_prediction_jumps"`.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";
}

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
 117
/// Newtype stored as a gpui global so the shared [`EditPredictionStore`] entity can be
/// looked up via `EditPredictionStore::try_global` / `EditPredictionStore::global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 122
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
// NOTE(review): "version" above presumably refers to `format`, which is parsed from a
// string named `version_str` in `zeta2_raw_config_from_env` — confirm.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier; read from `ZED_ZETA_MODEL` when built from the environment.
    pub model_id: Option<String>,
    /// Prompt format; parsed from `ZED_ZETA_FORMAT` when built from the environment.
    pub format: ZetaFormat,
}
 131
/// Central store for edit prediction state, shared as a gpui entity (see
/// `EditPredictionStore::global`).
pub struct EditPredictionStore {
    client: Arc<Client>,
    /// Read by `usage` to report edit prediction usage.
    user_store: Entity<UserStore>,
    /// Token refreshed whenever the `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive for the lifetime of the store.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// Initialised to `false` in `new`.
    update_required: bool,
    /// Model in use; `new` initialises this to `EditPredictionModel::Zeta2`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration, initially read from the environment.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Sending half of the channel drained by the background task spawned in `new`
    /// (`handle_rejected_predictions`).
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
 148
/// Selects which prediction backend the store uses. Among other things this decides the
/// icon set returned by `icons` and whether `usage` reports anything (Zeta variants only).
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta1,
    Zeta2,
    /// Fill-in-the-middle model using the given prompt format.
    Fim { format: EditPredictionPromptFormat },
    Sweep,
    Mercury,
}
 157
/// Bundle of values handed to an edit prediction model.
/// NOTE(review): presumably one instance per prediction request — confirm at call sites.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional channel for emitting [`DebugEvent`]s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
 172
/// Events sent over the optional `debug_tx` channels (`mpsc::UnboundedSender<DebugEvent>`).
/// Each variant wraps the payload struct of the same name.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
 180
/// Payload of [`DebugEvent::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    /// Entity id of the project the retrieval belongs to.
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}
 187
/// Payload of [`DebugEvent::ContextRetrievalFinished`].
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    /// Entity id of the project the retrieval belongs to.
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Static-label / value pairs attached to the event.
    pub metadata: Vec<(&'static str, SharedString)>,
}
 194
/// Payload of [`DebugEvent::EditPredictionStarted`].
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    /// Weak handle, so holding the event does not keep the buffer alive.
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Prompt text, when one is available.
    pub prompt: Option<String>,
}
 201
/// Payload of [`DebugEvent::EditPredictionFinished`].
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    /// Weak handle, so holding the event does not keep the buffer alive.
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Model output text, when one is available.
    pub model_output: Option<String>,
}
 208
/// Maximum number of [`UserActionRecord`]s kept per project; the oldest entry is evicted
/// by `ProjectState::record_user_action` once this many are stored.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// A single recorded user action.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    /// Timestamp in milliseconds (presumably since the Unix epoch, per the name — confirm).
    pub timestamp_epoch_ms: u64,
}
 219
/// Kind of action stored in a [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 228
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-facing event.
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot the change started from (`LastEvent::old_snapshot` at finalize time).
    pub old_snapshot: TextBufferSnapshot,
    /// Anchors spanning the edit, created in the post-change snapshot by `LastEvent::finalize`.
    pub edit_range: Range<Anchor>,
}
 236
 237impl StoredEvent {
 238    fn can_merge(
 239        &self,
 240        next_old_event: &&&StoredEvent,
 241        new_snapshot: &TextBufferSnapshot,
 242        last_edit_range: &Range<Anchor>,
 243    ) -> bool {
 244        // Events must be for the same buffer
 245        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 246            return false;
 247        }
 248        if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
 249            return false;
 250        }
 251
 252        let a_is_predicted = matches!(
 253            self.event.as_ref(),
 254            zeta_prompt::Event::BufferChange {
 255                predicted: true,
 256                ..
 257            }
 258        );
 259        let b_is_predicted = matches!(
 260            next_old_event.event.as_ref(),
 261            zeta_prompt::Event::BufferChange {
 262                predicted: true,
 263                ..
 264            }
 265        );
 266
 267        // If events come from the same source (both predicted or both manual) then
 268        // we would have coalesced them already.
 269        if a_is_predicted == b_is_predicted {
 270            return false;
 271        }
 272
 273        let left_range = self.edit_range.to_point(new_snapshot);
 274        let right_range = next_old_event.edit_range.to_point(new_snapshot);
 275        let latest_range = last_edit_range.to_point(&new_snapshot);
 276
 277        // Events near to the latest edit are not merged if their sources differ.
 278        if lines_between_ranges(&left_range, &latest_range)
 279            .min(lines_between_ranges(&right_range, &latest_range))
 280            <= CHANGE_GROUPING_LINE_SPAN
 281        {
 282            return false;
 283        }
 284
 285        // Events that are distant from each other are not merged.
 286        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 287            return false;
 288        }
 289
 290        true
 291    }
 292}
 293
 294fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 295    if left.start > right.end {
 296        return left.start.row - right.end.row;
 297    }
 298    if right.start > left.end {
 299        return right.start.row - left.end.row;
 300    }
 301    0
 302}
 303
/// Per-project state held in `EditPredictionStore::projects`.
struct ProjectState {
    /// Finalized edit events.
    events: VecDeque<StoredEvent>,
    /// Most recent, not yet finalized event; `events()` splits and finalizes it on the fly.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers registered for this project, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight predictions; the fixed capacity allows at most two.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions passed to `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// Consulted by `LastEvent::finalize` to decide whether a file is in an open-source project.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Recent user actions, capped at `USER_ACTION_HISTORY_SIZE` by `record_user_action`.
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    /// Created on demand by `EditPredictionStore::start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
 322
 323impl ProjectState {
 324    fn record_user_action(&mut self, action: UserActionRecord) {
 325        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 326            self.user_actions.pop_front();
 327        }
 328        self.user_actions.push_back(action);
 329    }
 330
 331    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 332        self.events
 333            .iter()
 334            .cloned()
 335            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 336                let (one, two) = event.split_by_pause();
 337                let one = one.finalize(&self.license_detection_watchers, cx);
 338                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 339                one.into_iter().chain(two)
 340            }))
 341            .collect()
 342    }
 343
 344    fn cancel_pending_prediction(
 345        &mut self,
 346        pending_prediction: PendingPrediction,
 347        cx: &mut Context<EditPredictionStore>,
 348    ) {
 349        self.cancelled_predictions.insert(pending_prediction.id);
 350
 351        if pending_prediction.drop_on_cancel {
 352            drop(pending_prediction.task);
 353        } else {
 354            cx.spawn(async move |this, cx| {
 355                let Some(prediction_id) = pending_prediction.task.await else {
 356                    return;
 357                };
 358
 359                this.update(cx, |this, cx| {
 360                    this.reject_prediction(
 361                        prediction_id,
 362                        EditPredictionRejectReason::Canceled,
 363                        false,
 364                        cx,
 365                    );
 366                })
 367                .ok();
 368            })
 369            .detach()
 370        }
 371    }
 372
 373    fn active_buffer(
 374        &self,
 375        project: &Entity<Project>,
 376        cx: &App,
 377    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 378        let project = project.read(cx);
 379        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 380        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 381        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 382        Some((active_buffer, registered_buffer.last_position))
 383    }
 384}
 385
/// The prediction currently held for a project, plus bookkeeping about how it was
/// requested and displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request (a specific buffer or a diagnostics update).
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    /// Display type the prediction was shown with, if any.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
 393
 394impl CurrentEditPrediction {
 395    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 396        let Some(new_edits) = self
 397            .prediction
 398            .interpolate(&self.prediction.buffer.read(cx))
 399        else {
 400            return false;
 401        };
 402
 403        if self.prediction.buffer != old_prediction.prediction.buffer {
 404            return true;
 405        }
 406
 407        let Some(old_edits) = old_prediction
 408            .prediction
 409            .interpolate(&old_prediction.prediction.buffer.read(cx))
 410        else {
 411            return true;
 412        };
 413
 414        let requested_by_buffer_id = self.requested_by.buffer_id();
 415
 416        // This reduces the occurrence of UI thrash from replacing edits
 417        //
 418        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 419        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 420            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 421            && old_edits.len() == 1
 422            && new_edits.len() == 1
 423        {
 424            let (old_range, old_text) = &old_edits[0];
 425            let (new_range, new_text) = &new_edits[0];
 426            new_range == old_range && new_text.starts_with(old_text.as_ref())
 427        } else {
 428            true
 429        }
 430    }
 431}
 432
/// Identifies what triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// Triggered by a diagnostics update; carries no buffer id.
    DiagnosticsUpdate,
    /// Triggered from the buffer with this entity id.
    Buffer(EntityId),
}
 438
 439impl PredictionRequestedBy {
 440    pub fn buffer_id(&self) -> Option<EntityId> {
 441        match self {
 442            PredictionRequestedBy::DiagnosticsUpdate => None,
 443            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 444        }
 445    }
 446}
 447
/// Number of lines used for diagnostic searches.
/// NOTE(review): the consumer is outside this view; presumably the half-width of the
/// search window around the cursor — confirm.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostic search.
/// NOTE(review): exact semantics of each variant are defined by callers outside this view.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
 455
/// An in-flight prediction request tracked in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    /// Resolves to the id of the produced prediction, or `None` when there is none.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 464
/// A prediction from the perspective of a buffer.
///
/// Both variants borrow the same kind of [`EditPrediction`]; the variant records how
/// the prediction relates to the buffer.
/// NOTE(review): presumably `Local` edits this buffer and `Jump` targets another
/// location — confirm where the variants are constructed.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
 471
 472#[cfg(test)]
 473impl std::ops::Deref for BufferEditPrediction<'_> {
 474    type Target = EditPrediction;
 475
 476    fn deref(&self) -> &Self::Target {
 477        match self {
 478            BufferEditPrediction::Local { prediction } => prediction,
 479            BufferEditPrediction::Jump { prediction } => prediction,
 480        }
 481    }
 482}
 483
/// State kept for each buffer registered with a project
/// (`ProjectState::registered_buffers`).
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    /// Last recorded position, returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
 490
/// The most recent, still-accumulating change to a buffer. It is converted into
/// [`StoredEvent`]s by `split_by_pause` followed by `finalize`.
#[derive(Clone)]
struct LastEvent {
    /// Snapshot the change started from.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot after the change.
    new_snapshot: TextBufferSnapshot,
    /// File backing the buffer before the change; `None` yields an `untitled-<id>` path.
    old_file: Option<Arc<dyn File>>,
    /// File backing the buffer after the change; `None` yields an `untitled-<id>` path.
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    /// Copied into the `predicted` flag of the resulting `BufferChange` event.
    predicted: bool,
    /// Boundary snapshot at which `split_by_pause` divides this event in two.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 502
 503impl LastEvent {
 504    pub fn finalize(
 505        &self,
 506        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 507        cx: &App,
 508    ) -> Option<StoredEvent> {
 509        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 510        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 511
 512        let in_open_source_repo =
 513            [self.new_file.as_ref(), self.old_file.as_ref()]
 514                .iter()
 515                .all(|file| {
 516                    file.is_some_and(|file| {
 517                        license_detection_watchers
 518                            .get(&file.worktree_id(cx))
 519                            .is_some_and(|watcher| watcher.is_project_open_source())
 520                    })
 521                });
 522
 523        let (diff, edit_range) =
 524            compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 525
 526        if path == old_path && diff.is_empty() {
 527            None
 528        } else {
 529            Some(StoredEvent {
 530                event: Arc::new(zeta_prompt::Event::BufferChange {
 531                    old_path,
 532                    path,
 533                    diff,
 534                    in_open_source_repo,
 535                    predicted: self.predicted,
 536                }),
 537                edit_range: self.new_snapshot.anchor_before(edit_range.start)
 538                    ..self.new_snapshot.anchor_before(edit_range.end),
 539                old_snapshot: self.old_snapshot.clone(),
 540            })
 541        }
 542    }
 543
 544    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 545        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 546            return (self.clone(), None);
 547        };
 548
 549        let before = LastEvent {
 550            old_snapshot: self.old_snapshot.clone(),
 551            new_snapshot: boundary_snapshot.clone(),
 552            old_file: self.old_file.clone(),
 553            new_file: self.new_file.clone(),
 554            edit_range: None,
 555            predicted: self.predicted,
 556            snapshot_after_last_editing_pause: None,
 557            last_edit_time: self.last_edit_time,
 558        };
 559
 560        let after = LastEvent {
 561            old_snapshot: boundary_snapshot.clone(),
 562            new_snapshot: self.new_snapshot.clone(),
 563            old_file: self.old_file.clone(),
 564            new_file: self.new_file.clone(),
 565            edit_range: None,
 566            predicted: self.predicted,
 567            snapshot_after_last_editing_pause: None,
 568            last_edit_time: self.last_edit_time,
 569        };
 570
 571        (before, Some(after))
 572    }
 573}
 574
 575pub(crate) fn compute_diff_between_snapshots(
 576    old_snapshot: &TextBufferSnapshot,
 577    new_snapshot: &TextBufferSnapshot,
 578) -> Option<(String, Range<Point>)> {
 579    let edits: Vec<Edit<usize>> = new_snapshot
 580        .edits_since::<usize>(&old_snapshot.version)
 581        .collect();
 582
 583    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 584
 585    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 586    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 587    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 588    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 589
 590    const CONTEXT_LINES: u32 = 3;
 591
 592    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 593    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 594    let old_context_end_row =
 595        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 596    let new_context_end_row =
 597        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 598
 599    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 600    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 601    let old_end_line_offset = old_snapshot
 602        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 603    let new_end_line_offset = new_snapshot
 604        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 605    let old_edit_range = old_start_line_offset..old_end_line_offset;
 606    let new_edit_range = new_start_line_offset..new_end_line_offset;
 607
 608    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 609    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 610
 611    let diff = language::unified_diff_with_offsets(
 612        &old_region_text,
 613        &new_region_text,
 614        old_context_start_row,
 615        new_context_start_row,
 616    );
 617
 618    Some((diff, new_start_point..new_end_point))
 619}
 620
 621fn buffer_path_with_id_fallback(
 622    file: Option<&Arc<dyn File>>,
 623    snapshot: &TextBufferSnapshot,
 624    cx: &App,
 625) -> Arc<Path> {
 626    if let Some(file) = file {
 627        file.full_path(cx).into()
 628    } else {
 629        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 630    }
 631}
 632
 633impl EditPredictionStore {
 634    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 635        cx.try_global::<EditPredictionStoreGlobal>()
 636            .map(|global| global.0.clone())
 637    }
 638
 639    pub fn global(
 640        client: &Arc<Client>,
 641        user_store: &Entity<UserStore>,
 642        cx: &mut App,
 643    ) -> Entity<Self> {
 644        cx.try_global::<EditPredictionStoreGlobal>()
 645            .map(|global| global.0.clone())
 646            .unwrap_or_else(|| {
 647                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 648                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 649                ep_store
 650            })
 651    }
 652
 653    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 654        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 655        let data_collection_choice = Self::load_data_collection_choice();
 656
 657        let llm_token = LlmApiToken::default();
 658
 659        let (reject_tx, reject_rx) = mpsc::unbounded();
 660        cx.background_spawn({
 661            let client = client.clone();
 662            let llm_token = llm_token.clone();
 663            let app_version = AppVersion::global(cx);
 664            let background_executor = cx.background_executor().clone();
 665            async move {
 666                Self::handle_rejected_predictions(
 667                    reject_rx,
 668                    client,
 669                    llm_token,
 670                    app_version,
 671                    background_executor,
 672                )
 673                .await
 674            }
 675        })
 676        .detach();
 677
 678        let this = Self {
 679            projects: HashMap::default(),
 680            client,
 681            user_store,
 682            llm_token,
 683            _llm_token_subscription: cx.subscribe(
 684                &refresh_llm_token_listener,
 685                |this, _listener, _event, cx| {
 686                    let client = this.client.clone();
 687                    let llm_token = this.llm_token.clone();
 688                    cx.spawn(async move |_this, _cx| {
 689                        llm_token.refresh(&client).await?;
 690                        anyhow::Ok(())
 691                    })
 692                    .detach_and_log_err(cx);
 693                },
 694            ),
 695            update_required: false,
 696            edit_prediction_model: EditPredictionModel::Zeta2,
 697            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 698            sweep_ai: SweepAi::new(cx),
 699            mercury: Mercury::new(cx),
 700
 701            data_collection_choice,
 702            reject_predictions_tx: reject_tx,
 703            rated_predictions: Default::default(),
 704            shown_predictions: Default::default(),
 705        };
 706
 707        this
 708    }
 709
 710    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 711        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 712        let format = ZetaFormat::parse(&version_str).ok()?;
 713        let model_id = env::var("ZED_ZETA_MODEL").ok();
 714        Some(Zeta2RawConfig { model_id, format })
 715    }
 716
    /// Replaces the model used by this store.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
 720
    /// Sets the raw Zeta2 configuration, overriding any previously stored value
    /// (including one read from the environment in `new`).
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }
 724
    /// Returns the raw Zeta2 configuration, if one is set.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }
 728
 729    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 730        use ui::IconName;
 731        match self.edit_prediction_model {
 732            EditPredictionModel::Sweep => {
 733                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 734                    .with_disabled(IconName::SweepAiDisabled)
 735                    .with_up(IconName::SweepAiUp)
 736                    .with_down(IconName::SweepAiDown)
 737                    .with_error(IconName::SweepAiError)
 738            }
 739            EditPredictionModel::Mercury => {
 740                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 741            }
 742            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
 743                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 744                    .with_disabled(IconName::ZedPredictDisabled)
 745                    .with_up(IconName::ZedPredictUp)
 746                    .with_down(IconName::ZedPredictDown)
 747                    .with_error(IconName::ZedPredictError)
 748            }
 749            EditPredictionModel::Fim { .. } => {
 750                let settings = &all_language_settings(None, cx).edit_predictions;
 751                match settings.provider {
 752                    EditPredictionProvider::Ollama => {
 753                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 754                    }
 755                    _ => {
 756                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 757                    }
 758                }
 759            }
 760        }
 761    }
 762
    /// Reports whether the Sweep API token state currently holds a key.
    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
        self.sweep_ai.api_token.read(cx).has_key()
    }
 766
    /// Reports whether the Mercury API token state currently holds a key.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }
 770
 771    pub fn clear_history(&mut self) {
 772        for project_state in self.projects.values_mut() {
 773            project_state.events.clear();
 774            project_state.last_event.take();
 775        }
 776    }
 777
 778    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 779        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 780            project_state.events.clear();
 781            project_state.last_event.take();
 782        }
 783    }
 784
 785    pub fn edit_history_for_project(
 786        &self,
 787        project: &Entity<Project>,
 788        cx: &App,
 789    ) -> Vec<StoredEvent> {
 790        self.projects
 791            .get(&project.entity_id())
 792            .map(|project_state| project_state.events(cx))
 793            .unwrap_or_default()
 794    }
 795
 796    pub fn context_for_project<'a>(
 797        &'a self,
 798        project: &Entity<Project>,
 799        cx: &'a mut App,
 800    ) -> Vec<RelatedFile> {
 801        self.projects
 802            .get(&project.entity_id())
 803            .map(|project_state| {
 804                project_state.context.update(cx, |context, cx| {
 805                    context
 806                        .related_files_with_buffers(cx)
 807                        .map(|(mut related_file, buffer)| {
 808                            related_file.in_open_source_repo = buffer
 809                                .read(cx)
 810                                .file()
 811                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
 812                            related_file
 813                        })
 814                        .collect()
 815                })
 816            })
 817            .unwrap_or_default()
 818    }
 819
 820    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 821        self.projects
 822            .get(&project.entity_id())
 823            .and_then(|project| project.copilot.clone())
 824    }
 825
 826    pub fn start_copilot_for_project(
 827        &mut self,
 828        project: &Entity<Project>,
 829        cx: &mut Context<Self>,
 830    ) -> Option<Entity<Copilot>> {
 831        if DisableAiSettings::get(None, cx).disable_ai {
 832            return None;
 833        }
 834        let state = self.get_or_init_project(project, cx);
 835
 836        if state.copilot.is_some() {
 837            return state.copilot.clone();
 838        }
 839        let _project = project.clone();
 840        let project = project.read(cx);
 841
 842        let node = project.node_runtime().cloned();
 843        if let Some(node) = node {
 844            let next_id = project.languages().next_language_server_id();
 845            let fs = project.fs().clone();
 846
 847            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 848            state.copilot = Some(copilot.clone());
 849            Some(copilot)
 850        } else {
 851            None
 852        }
 853    }
 854
 855    pub fn context_for_project_with_buffers<'a>(
 856        &'a self,
 857        project: &Entity<Project>,
 858        cx: &'a mut App,
 859    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 860        self.projects
 861            .get(&project.entity_id())
 862            .map(|project| {
 863                project.context.update(cx, |context, cx| {
 864                    context.related_files_with_buffers(cx).collect()
 865                })
 866            })
 867            .unwrap_or_default()
 868    }
 869
 870    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 871        if matches!(
 872            self.edit_prediction_model,
 873            EditPredictionModel::Zeta2 | EditPredictionModel::Zeta1
 874        ) {
 875            self.user_store.read(cx).edit_prediction_usage()
 876        } else {
 877            None
 878        }
 879    }
 880
 881    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 882        self.get_or_init_project(project, cx);
 883    }
 884
 885    pub fn register_buffer(
 886        &mut self,
 887        buffer: &Entity<Buffer>,
 888        project: &Entity<Project>,
 889        cx: &mut Context<Self>,
 890    ) {
 891        let project_state = self.get_or_init_project(project, cx);
 892        Self::register_buffer_impl(project_state, buffer, project, cx);
 893    }
 894
 895    fn get_or_init_project(
 896        &mut self,
 897        project: &Entity<Project>,
 898        cx: &mut Context<Self>,
 899    ) -> &mut ProjectState {
 900        let entity_id = project.entity_id();
 901        self.projects
 902            .entry(entity_id)
 903            .or_insert_with(|| ProjectState {
 904                context: {
 905                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 906                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 907                        this.handle_excerpt_store_event(entity_id, event);
 908                    })
 909                    .detach();
 910                    related_excerpt_store
 911                },
 912                events: VecDeque::new(),
 913                last_event: None,
 914                recent_paths: VecDeque::new(),
 915                debug_tx: None,
 916                registered_buffers: HashMap::default(),
 917                current_prediction: None,
 918                cancelled_predictions: HashSet::default(),
 919                pending_predictions: ArrayVec::new(),
 920                next_pending_prediction_id: 0,
 921                last_edit_prediction_refresh: None,
 922                last_jump_prediction_refresh: None,
 923                license_detection_watchers: HashMap::default(),
 924                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
 925                _subscriptions: [
 926                    cx.subscribe(&project, Self::handle_project_event),
 927                    cx.observe_release(&project, move |this, _, cx| {
 928                        this.projects.remove(&entity_id);
 929                        cx.notify();
 930                    }),
 931                ],
 932                copilot: None,
 933            })
 934    }
 935
 936    pub fn remove_project(&mut self, project: &Entity<Project>) {
 937        self.projects.remove(&project.entity_id());
 938    }
 939
 940    fn handle_excerpt_store_event(
 941        &mut self,
 942        project_entity_id: EntityId,
 943        event: &RelatedExcerptStoreEvent,
 944    ) {
 945        if let Some(project_state) = self.projects.get(&project_entity_id) {
 946            if let Some(debug_tx) = project_state.debug_tx.clone() {
 947                match event {
 948                    RelatedExcerptStoreEvent::StartedRefresh => {
 949                        debug_tx
 950                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 951                                ContextRetrievalStartedDebugEvent {
 952                                    project_entity_id: project_entity_id,
 953                                    timestamp: Instant::now(),
 954                                    search_prompt: String::new(),
 955                                },
 956                            ))
 957                            .ok();
 958                    }
 959                    RelatedExcerptStoreEvent::FinishedRefresh {
 960                        cache_hit_count,
 961                        cache_miss_count,
 962                        mean_definition_latency,
 963                        max_definition_latency,
 964                    } => {
 965                        debug_tx
 966                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 967                                ContextRetrievalFinishedDebugEvent {
 968                                    project_entity_id: project_entity_id,
 969                                    timestamp: Instant::now(),
 970                                    metadata: vec![
 971                                        (
 972                                            "Cache Hits",
 973                                            format!(
 974                                                "{}/{}",
 975                                                cache_hit_count,
 976                                                cache_hit_count + cache_miss_count
 977                                            )
 978                                            .into(),
 979                                        ),
 980                                        (
 981                                            "Max LSP Time",
 982                                            format!("{} ms", max_definition_latency.as_millis())
 983                                                .into(),
 984                                        ),
 985                                        (
 986                                            "Mean LSP Time",
 987                                            format!("{} ms", mean_definition_latency.as_millis())
 988                                                .into(),
 989                                        ),
 990                                    ],
 991                                },
 992                            ))
 993                            .ok();
 994                    }
 995                }
 996            }
 997        }
 998    }
 999
1000    pub fn debug_info(
1001        &mut self,
1002        project: &Entity<Project>,
1003        cx: &mut Context<Self>,
1004    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1005        let project_state = self.get_or_init_project(project, cx);
1006        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1007        project_state.debug_tx = Some(debug_watch_tx);
1008        debug_watch_rx
1009    }
1010
1011    fn handle_project_event(
1012        &mut self,
1013        project: Entity<Project>,
1014        event: &project::Event,
1015        cx: &mut Context<Self>,
1016    ) {
1017        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1018            return;
1019        }
1020        // TODO [zeta2] init with recent paths
1021        match event {
1022            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1023                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1024                    return;
1025                };
1026                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1027                if let Some(path) = path {
1028                    if let Some(ix) = project_state
1029                        .recent_paths
1030                        .iter()
1031                        .position(|probe| probe == &path)
1032                    {
1033                        project_state.recent_paths.remove(ix);
1034                    }
1035                    project_state.recent_paths.push_front(path);
1036                }
1037            }
1038            project::Event::DiagnosticsUpdated { .. } => {
1039                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1040                    self.refresh_prediction_from_diagnostics(
1041                        project,
1042                        DiagnosticSearchScope::Global,
1043                        cx,
1044                    );
1045                }
1046            }
1047            _ => (),
1048        }
1049    }
1050
1051    fn register_buffer_impl<'a>(
1052        project_state: &'a mut ProjectState,
1053        buffer: &Entity<Buffer>,
1054        project: &Entity<Project>,
1055        cx: &mut Context<Self>,
1056    ) -> &'a mut RegisteredBuffer {
1057        let buffer_id = buffer.entity_id();
1058
1059        if let Some(file) = buffer.read(cx).file() {
1060            let worktree_id = file.worktree_id(cx);
1061            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
1062                project_state
1063                    .license_detection_watchers
1064                    .entry(worktree_id)
1065                    .or_insert_with(|| {
1066                        let project_entity_id = project.entity_id();
1067                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
1068                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
1069                            else {
1070                                return;
1071                            };
1072                            project_state
1073                                .license_detection_watchers
1074                                .remove(&worktree_id);
1075                        })
1076                        .detach();
1077                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
1078                    });
1079            }
1080        }
1081
1082        match project_state.registered_buffers.entry(buffer_id) {
1083            hash_map::Entry::Occupied(entry) => entry.into_mut(),
1084            hash_map::Entry::Vacant(entry) => {
1085                let buf = buffer.read(cx);
1086                let snapshot = buf.text_snapshot();
1087                let file = buf.file().cloned();
1088                let project_entity_id = project.entity_id();
1089                entry.insert(RegisteredBuffer {
1090                    snapshot,
1091                    file,
1092                    last_position: None,
1093                    _subscriptions: [
1094                        cx.subscribe(buffer, {
1095                            let project = project.downgrade();
1096                            move |this, buffer, event, cx| {
1097                                if let language::BufferEvent::Edited = event
1098                                    && let Some(project) = project.upgrade()
1099                                {
1100                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
1101                                }
1102                            }
1103                        }),
1104                        cx.observe_release(buffer, move |this, _buffer, _cx| {
1105                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
1106                            else {
1107                                return;
1108                            };
1109                            project_state.registered_buffers.remove(&buffer_id);
1110                        }),
1111                    ],
1112                })
1113            }
1114        }
1115    }
1116
    /// Records the edits made to `buffer` since its last registered snapshot.
    ///
    /// The edits are summarized into a `UserActionRecord` and folded into the project's
    /// edit history: they are either coalesced into the pending `last_event` or the
    /// pending event is finalized into `events` and a new `last_event` is started.
    ///
    /// `is_predicted` is stored on the event as `predicted`; edits with a different
    /// `predicted` value than the pending event are never coalesced into it.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the snapshot we last recorded for this buffer.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        // Store the new file/snapshot on the registration and keep the previous ones
        // to describe the "before" side of the event.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        // Accumulate totals over all edits and a single anchor range running from the
        // first edit's start to the last edit's end.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        // No edits were reported even though the version changed.
        let Some(edit_range) = edit_range else {
            return;
        };

        // Classify the change from the totals. A total equal to the number of edits
        // (one unit of offset per edit) is treated as a single-character action;
        // presumably this covers typing with multiple cursors — TODO confirm.
        // Replacements (both deleted and inserted text) are classified as insertions.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        // `last_offset` is set in the same loop iterations as `edit_range`, so it is
        // always `Some` here.
        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // Falls back to 0 if the system clock is before the Unix epoch.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // The pending event must end exactly where this change starts: same buffer,
            // and its newest snapshot is the snapshot this change was made against.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            // Additionally require the new edit range to be within
            // `CHANGE_GROUPING_LINE_SPAN` of the pending event's edit range, as measured
            // by `lines_between_ranges`.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // If at least `LAST_CHANGE_GROUPING_TIME` passed since the previous edit,
                // remember the snapshot from before this edit as the state after the
                // last editing pause.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                // Extend the pending event in place instead of starting a new one.
                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescing: finalize the pending event (which may yield nothing) and move
        // it into the history, evicting the oldest entry when the history is at capacity.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        // NOTE(review): presumably merges trailing history entries that belong with this
        // change — confirm against `merge_trailing_events_if_needed`.
        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        // Start a new pending event describing this change.
        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1244
1245    fn prediction_at(
1246        &mut self,
1247        buffer: &Entity<Buffer>,
1248        position: Option<language::Anchor>,
1249        project: &Entity<Project>,
1250        cx: &App,
1251    ) -> Option<BufferEditPrediction<'_>> {
1252        let project_state = self.projects.get_mut(&project.entity_id())?;
1253        if let Some(position) = position
1254            && let Some(buffer) = project_state
1255                .registered_buffers
1256                .get_mut(&buffer.entity_id())
1257        {
1258            buffer.last_position = Some(position);
1259        }
1260
1261        let CurrentEditPrediction {
1262            requested_by,
1263            prediction,
1264            ..
1265        } = project_state.current_prediction.as_ref()?;
1266
1267        if prediction.targets_buffer(buffer.read(cx)) {
1268            Some(BufferEditPrediction::Local { prediction })
1269        } else {
1270            let show_jump = match requested_by {
1271                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1272                    requested_by_buffer_id == &buffer.entity_id()
1273                }
1274                PredictionRequestedBy::DiagnosticsUpdate => true,
1275            };
1276
1277            if show_jump {
1278                Some(BufferEditPrediction::Jump { prediction })
1279            } else {
1280                None
1281            }
1282        }
1283    }
1284
1285    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1286        let Some(current_prediction) = self
1287            .projects
1288            .get_mut(&project.entity_id())
1289            .and_then(|project_state| project_state.current_prediction.take())
1290        else {
1291            return;
1292        };
1293
1294        self.report_changes_for_buffer(&current_prediction.prediction.buffer, project, true, cx);
1295
1296        // can't hold &mut project_state ref across report_changes_for_buffer_call
1297        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1298            return;
1299        };
1300
1301        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1302            project_state.cancel_pending_prediction(pending_prediction, cx);
1303        }
1304
1305        match self.edit_prediction_model {
1306            EditPredictionModel::Sweep => {
1307                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1308            }
1309            EditPredictionModel::Mercury => {
1310                mercury::edit_prediction_accepted(
1311                    current_prediction.prediction.id,
1312                    self.client.http_client(),
1313                    cx,
1314                );
1315            }
1316            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1317                let is_cloud = !matches!(
1318                    all_language_settings(None, cx).edit_predictions.provider,
1319                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1320                );
1321                if is_cloud {
1322                    zeta::edit_prediction_accepted(self, current_prediction, cx)
1323                }
1324            }
1325            EditPredictionModel::Fim { .. } => {}
1326        }
1327    }
1328
1329    async fn handle_rejected_predictions(
1330        rx: UnboundedReceiver<EditPredictionRejection>,
1331        client: Arc<Client>,
1332        llm_token: LlmApiToken,
1333        app_version: Version,
1334        background_executor: BackgroundExecutor,
1335    ) {
1336        let mut rx = std::pin::pin!(rx.peekable());
1337        let mut batched = Vec::new();
1338
1339        while let Some(rejection) = rx.next().await {
1340            batched.push(rejection);
1341
1342            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1343                select_biased! {
1344                    next = rx.as_mut().peek().fuse() => {
1345                        if next.is_some() {
1346                            continue;
1347                        }
1348                    }
1349                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1350                }
1351            }
1352
1353            let url = client
1354                .http_client()
1355                .build_zed_llm_url("/predict_edits/reject", &[])
1356                .unwrap();
1357
1358            let flush_count = batched
1359                .len()
1360                // in case items have accumulated after failure
1361                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1362            let start = batched.len() - flush_count;
1363
1364            let body = RejectEditPredictionsBodyRef {
1365                rejections: &batched[start..],
1366            };
1367
1368            let result = Self::send_api_request::<()>(
1369                |builder| {
1370                    let req = builder
1371                        .uri(url.as_ref())
1372                        .body(serde_json::to_string(&body)?.into());
1373                    anyhow::Ok(req?)
1374                },
1375                client.clone(),
1376                llm_token.clone(),
1377                app_version.clone(),
1378                true,
1379            )
1380            .await;
1381
1382            if result.log_err().is_some() {
1383                batched.drain(start..);
1384            }
1385        }
1386    }
1387
1388    fn reject_current_prediction(
1389        &mut self,
1390        reason: EditPredictionRejectReason,
1391        project: &Entity<Project>,
1392        cx: &App,
1393    ) {
1394        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1395            project_state.pending_predictions.clear();
1396            if let Some(prediction) = project_state.current_prediction.take() {
1397                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
1398            }
1399        };
1400    }
1401
1402    fn did_show_current_prediction(
1403        &mut self,
1404        project: &Entity<Project>,
1405        display_type: edit_prediction_types::SuggestionDisplayType,
1406        cx: &mut Context<Self>,
1407    ) {
1408        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1409            return;
1410        };
1411
1412        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1413            return;
1414        };
1415
1416        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1417        let previous_shown_with = current_prediction.shown_with;
1418
1419        if previous_shown_with.is_none() || !is_jump {
1420            current_prediction.shown_with = Some(display_type);
1421        }
1422
1423        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1424
1425        if is_first_non_jump_show {
1426            current_prediction.was_shown = true;
1427        }
1428
1429        let display_type_changed = previous_shown_with != Some(display_type);
1430
1431        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1432            sweep_ai::edit_prediction_shown(
1433                &self.sweep_ai,
1434                self.client.clone(),
1435                &current_prediction.prediction,
1436                display_type,
1437                cx,
1438            );
1439        }
1440
1441        if is_first_non_jump_show {
1442            self.shown_predictions
1443                .push_front(current_prediction.prediction.clone());
1444            if self.shown_predictions.len() > 50 {
1445                let completion = self.shown_predictions.pop_back().unwrap();
1446                self.rated_predictions.remove(&completion.id);
1447            }
1448        }
1449    }
1450
1451    fn reject_prediction(
1452        &mut self,
1453        prediction_id: EditPredictionId,
1454        reason: EditPredictionRejectReason,
1455        was_shown: bool,
1456        cx: &App,
1457    ) {
1458        match self.edit_prediction_model {
1459            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1460                let is_cloud = !matches!(
1461                    all_language_settings(None, cx).edit_predictions.provider,
1462                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1463                );
1464                if is_cloud {
1465                    self.reject_predictions_tx
1466                        .unbounded_send(EditPredictionRejection {
1467                            request_id: prediction_id.to_string(),
1468                            reason,
1469                            was_shown,
1470                        })
1471                        .log_err();
1472                }
1473            }
1474            EditPredictionModel::Mercury => {
1475                mercury::edit_prediction_rejected(
1476                    prediction_id,
1477                    was_shown,
1478                    reason,
1479                    self.client.http_client(),
1480                    cx,
1481                );
1482            }
1483            EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1484        }
1485    }
1486
1487    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1488        self.projects
1489            .get(&project.entity_id())
1490            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1491    }
1492
1493    pub fn refresh_prediction_from_buffer(
1494        &mut self,
1495        project: Entity<Project>,
1496        buffer: Entity<Buffer>,
1497        position: language::Anchor,
1498        cx: &mut Context<Self>,
1499    ) {
1500        self.queue_prediction_refresh(
1501            project.clone(),
1502            PredictEditsRequestTrigger::Other,
1503            buffer.entity_id(),
1504            cx,
1505            move |this, cx| {
1506                let Some(request_task) = this
1507                    .update(cx, |this, cx| {
1508                        this.request_prediction(
1509                            &project,
1510                            &buffer,
1511                            position,
1512                            PredictEditsRequestTrigger::Other,
1513                            cx,
1514                        )
1515                    })
1516                    .log_err()
1517                else {
1518                    return Task::ready(anyhow::Ok(None));
1519                };
1520
1521                cx.spawn(async move |_cx| {
1522                    request_task.await.map(|prediction_result| {
1523                        prediction_result.map(|prediction_result| {
1524                            (
1525                                prediction_result,
1526                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1527                            )
1528                        })
1529                    })
1530                })
1531            },
1532        )
1533    }
1534
    /// Queues a prediction refresh driven by the project's diagnostics.
    ///
    /// Nothing is queued when the configured provider does not satisfy
    /// `is_ep_store_provider`, when `project` is unknown to this store, or when the
    /// project already has a current prediction. The queued refresh looks up the next
    /// diagnostic location relative to the active buffer's cursor (restricted by `scope`)
    /// and requests a prediction there with the `Diagnostics` trigger; the result is
    /// tagged as requested by a diagnostics update.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        }

        // The refresh is keyed by the project's entity id rather than a buffer's.
        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                // Resolve the active buffer and cursor; resolve to no prediction if the
                // store is gone, there is no active buffer, or predictions are disabled
                // at that position.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        // Without a known position, the default point is used.
                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    let diagnostic_search_range = match scope {
                        // Rows within `DIAGNOSTIC_LINES_RANGE` of the cursor row.
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        // NOTE(review): the default range is passed here; presumably
                        // `next_diagnostic_location` does not restrict the search to it
                        // in the global case — confirm.
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // A prediction may have appeared while this request was in flight;
                    // in that case report this one as rejected with `CurrentPreferred`
                    // instead of replacing the current prediction.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
1643
1644    fn predictions_enabled_at(
1645        snapshot: &BufferSnapshot,
1646        position: Option<language::Anchor>,
1647        cx: &App,
1648    ) -> bool {
1649        let file = snapshot.file();
1650        let all_settings = all_language_settings(file, cx);
1651        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1652            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1653        {
1654            return false;
1655        }
1656
1657        if let Some(last_position) = position {
1658            let settings = snapshot.settings_at(last_position, cx);
1659
1660            if !settings.edit_predictions_disabled_in.is_empty()
1661                && let Some(scope) = snapshot.language_scope_at(last_position)
1662                && let Some(scope_name) = scope.override_name()
1663                && settings
1664                    .edit_predictions_disabled_in
1665                    .iter()
1666                    .any(|s| s == scope_name)
1667            {
1668                return false;
1669            }
1670        }
1671
1672        true
1673    }
1674
    /// Length of the throttle window used by `queue_prediction_refresh`: a new
    /// refresh for the same throttle entity (and the same trigger kind) waits
    /// until this much time has passed since the previous one was started.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1676}
1677
1678fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1679    match provider {
1680        EditPredictionProvider::Zed
1681        | EditPredictionProvider::Sweep
1682        | EditPredictionProvider::Mercury
1683        | EditPredictionProvider::Ollama
1684        | EditPredictionProvider::OpenAiCompatibleApi
1685        | EditPredictionProvider::Experimental(_) => true,
1686        EditPredictionProvider::None
1687        | EditPredictionProvider::Copilot
1688        | EditPredictionProvider::Supermaven
1689        | EditPredictionProvider::Codestral => false,
1690    }
1691}
1692
1693impl EditPredictionStore {
    /// Schedules a throttled prediction refresh for `project` and records it in
    /// the project's list of pending predictions.
    ///
    /// * `request_trigger` selects which throttle slot is used: diagnostics
    ///   triggers use `last_jump_prediction_refresh`, everything else uses
    ///   `last_edit_prediction_refresh`.
    /// * `throttle_entity` identifies what the request is throttled against; the
    ///   task only waits out [`Self::THROTTLE_TIMEOUT`] when the previous refresh
    ///   in the same slot was issued for the same entity.
    /// * `do_refresh` is called inside the spawned task to actually produce a
    ///   prediction. Its result is either installed as the project's current
    ///   prediction or rejected.
    ///
    /// Logs an error and returns without scheduling anything when the configured
    /// provider is `None`, `Copilot`, `Supermaven` or `Codestral`.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        request_trigger: PredictEditsRequestTrigger,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        // Picks the throttle slot that belongs to the given trigger kind.
        fn select_throttle(
            project_state: &mut ProjectState,
            request_trigger: PredictEditsRequestTrigger,
        ) -> &mut Option<(EntityId, Instant)> {
            match request_trigger {
                PredictEditsRequestTrigger::Diagnostics => {
                    &mut project_state.last_jump_prediction_refresh
                }
                _ => &mut project_state.last_edit_prediction_refresh,
            }
        }

        // Per-provider policy: whether acceptance tracking is needed, and how many
        // predictions may be pending at once.
        let (needs_acceptance_tracking, max_pending_predictions) =
            match all_language_settings(None, cx).edit_predictions.provider {
                EditPredictionProvider::Zed
                | EditPredictionProvider::Sweep
                | EditPredictionProvider::Mercury
                | EditPredictionProvider::Experimental(_) => (true, 2),
                EditPredictionProvider::Ollama => (false, 1),
                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
                EditPredictionProvider::None
                | EditPredictionProvider::Copilot
                | EditPredictionProvider::Supermaven
                | EditPredictionProvider::Codestral => {
                    log::error!("queue_prediction_refresh called with non-store provider");
                    return;
                }
            };

        // Stored on the `PendingPrediction`; how it is used is decided by
        // `cancel_pending_prediction`, which is not visible in this block.
        let drop_on_cancel = !needs_acceptance_tracking;
        let throttle_timeout = Self::THROTTLE_TIMEOUT;
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        // Copy of the throttle slot as it was when this request was enqueued.
        let last_request = *select_throttle(project_state, request_trigger);

        let task = cx.spawn(async move |this, cx| {
            // Sleep for whatever is left of the throttle window, but only if the
            // previous refresh was for the same entity. `checked_duration_since`
            // yields `None` once the window has already elapsed.
            if let Some(timeout) = last_request.and_then(|(last_entity, last_timestamp)| {
                if throttle_entity != last_entity {
                    return None;
                }
                (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
            }) {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            // Starts out `true` so that a failed `this.update` (the result is
            // discarded with `.ok()`) is treated like a cancellation.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                let was_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);
                if !was_cancelled {
                    // Record this refresh as the most recent one for throttling.
                    let new_refresh = (throttle_entity, Instant::now());
                    *select_throttle(project_state, request_trigger) = Some(new_refresh);
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            // Errors from `do_refresh` are logged and treated as "no prediction".
            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                // Checked a second time: the request may have been cancelled while
                // `do_refresh` was running.
                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                                shown_with: None,
                            };

                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                // A prediction is already current: keep whichever
                                // one `should_replace_prediction` prefers and
                                // reject the other.
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                        cx,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                        cx,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            // The refresh itself already decided to reject.
                            this.reject_prediction(prediction_result.id, reject_reason, false, cx);
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-borrow: the calls on `this` above ended the earlier borrow.
                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Take the list out so entries can be passed by value to
                // `cancel_pending_prediction`; it is written back below.
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        // Drop our own entry, then cancel everything queued before it.
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        if project_state.pending_predictions.len() < max_pending_predictions {
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
        } else {
            // At capacity: the most recently queued entry is replaced by this one
            // and then cancelled.
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1872
1873    pub fn request_prediction(
1874        &mut self,
1875        project: &Entity<Project>,
1876        active_buffer: &Entity<Buffer>,
1877        position: language::Anchor,
1878        trigger: PredictEditsRequestTrigger,
1879        cx: &mut Context<Self>,
1880    ) -> Task<Result<Option<EditPredictionResult>>> {
1881        self.request_prediction_internal(
1882            project.clone(),
1883            active_buffer.clone(),
1884            position,
1885            trigger,
1886            cx.has_flag::<Zeta2FeatureFlag>(),
1887            cx,
1888        )
1889    }
1890
1891    fn request_prediction_internal(
1892        &mut self,
1893        project: Entity<Project>,
1894        active_buffer: Entity<Buffer>,
1895        position: language::Anchor,
1896        trigger: PredictEditsRequestTrigger,
1897        allow_jump: bool,
1898        cx: &mut Context<Self>,
1899    ) -> Task<Result<Option<EditPredictionResult>>> {
1900        self.get_or_init_project(&project, cx);
1901        let project_state = self.projects.get(&project.entity_id()).unwrap();
1902        let stored_events = project_state.events(cx);
1903        let has_events = !stored_events.is_empty();
1904        let events: Vec<Arc<zeta_prompt::Event>> =
1905            stored_events.into_iter().map(|e| e.event).collect();
1906        let debug_tx = project_state.debug_tx.clone();
1907
1908        let snapshot = active_buffer.read(cx).snapshot();
1909        let cursor_point = position.to_point(&snapshot);
1910        let current_offset = position.to_offset(&snapshot);
1911
1912        let mut user_actions: Vec<UserActionRecord> =
1913            project_state.user_actions.iter().cloned().collect();
1914
1915        if let Some(last_action) = user_actions.last() {
1916            if last_action.buffer_id == active_buffer.entity_id()
1917                && current_offset != last_action.offset
1918            {
1919                let timestamp_epoch_ms = SystemTime::now()
1920                    .duration_since(UNIX_EPOCH)
1921                    .map(|d| d.as_millis() as u64)
1922                    .unwrap_or(0);
1923                user_actions.push(UserActionRecord {
1924                    action_type: UserActionType::CursorMovement,
1925                    buffer_id: active_buffer.entity_id(),
1926                    line_number: cursor_point.row,
1927                    offset: current_offset,
1928                    timestamp_epoch_ms,
1929                });
1930            }
1931        }
1932        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1933        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1934        let diagnostic_search_range =
1935            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1936
1937        let related_files = self.context_for_project(&project, cx);
1938
1939        let inputs = EditPredictionModelInput {
1940            project: project.clone(),
1941            buffer: active_buffer,
1942            snapshot: snapshot,
1943            position,
1944            events,
1945            related_files,
1946            recent_paths: project_state.recent_paths.clone(),
1947            trigger,
1948            diagnostic_search_range: diagnostic_search_range,
1949            debug_tx,
1950            user_actions,
1951        };
1952
1953        let task = match self.edit_prediction_model {
1954            EditPredictionModel::Zeta1 => zeta::request_prediction_with_zeta(
1955                self,
1956                inputs,
1957                Some(zeta_prompt::EditPredictionModelKind::Zeta1),
1958                cx,
1959            ),
1960            EditPredictionModel::Zeta2 => zeta::request_prediction_with_zeta(
1961                self,
1962                inputs,
1963                Some(zeta_prompt::EditPredictionModelKind::Zeta2),
1964                cx,
1965            ),
1966            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
1967            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1968            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1969        };
1970
1971        cx.spawn(async move |this, cx| {
1972            let prediction = task.await?;
1973
1974            if prediction.is_none() && allow_jump && has_events {
1975                this.update(cx, |this, cx| {
1976                    this.refresh_prediction_from_diagnostics(
1977                        project,
1978                        DiagnosticSearchScope::Local,
1979                        cx,
1980                    );
1981                })?;
1982                return anyhow::Ok(None);
1983            }
1984
1985            Ok(prediction)
1986        })
1987    }
1988
    /// Finds a diagnostic to jump to, preferring the active buffer.
    ///
    /// 1. In the active buffer, considers the primary entry of every diagnostic
    ///    group whose range does not overlap
    ///    `active_buffer_diagnostic_search_range`. Entries that start within
    ///    `DIAGNOSTIC_LINES_RANGE` rows of a cursor reported by
    ///    `selections_in_range` are skipped unless they are equally close to
    ///    `active_buffer_cursor_point`. The remaining entry whose start row is
    ///    closest to the cursor wins.
    /// 2. If nothing was found, walks the other project paths that have
    ///    diagnostic summaries, ordered by how many leading path components they
    ///    share with the active buffer's path (most first), opens each buffer and
    ///    takes the diagnostic entry with the smallest `severity` value from the
    ///    first buffer that has no selections reported by `selections_in_range`.
    ///
    /// Returns the buffer and anchor to jump to, or `None` when no candidate
    /// exists. Fails if opening a candidate buffer fails.
    pub(crate) async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Rows of every selection head in the buffer.
        // NOTE(review): the `false` argument presumably excludes local
        // selections, leaving only collaborators' cursors — confirm.
        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
            .flat_map(|(_, _, _, selections)| {
                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
            })
            .collect();

        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Diagnostics inside the search range are not jump targets.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    return None;
                }
                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
                });
                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
                    <= DIAGNOSTIC_LINES_RANGE;
                // Skip diagnostics that sit next to another cursor but not next
                // to ours.
                if near_collaborator && !near_local {
                    return None;
                }
                Some(range.start)
            })
            // Closest remaining diagnostic by row distance from the cursor.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // `None` when the active buffer has no file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // Every other path with diagnostics, paired with the number of
            // leading path components it shares with the active buffer's path.
            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
                project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .map(|(path, _, _)| {
                        let shared_prefix = path
                            .path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count();
                        (path, shared_prefix)
                    })
                    .collect()
            });

            // Longest shared prefix first.
            candidates.sort_by(|a, b| b.1.cmp(&a.1));

            for (path, _) in candidates {
                let candidate_buffer = project
                    .update(cx, |project, cx| project.open_buffer(path, cx))
                    .await?;

                let (has_collaborators, diagnostic_position) =
                    candidate_buffer.read_with(cx, |buffer, _cx| {
                        let snapshot = buffer.snapshot();
                        let has_collaborators = snapshot
                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
                            .next()
                            .is_some();
                        // NOTE(review): the smallest severity value is presumably
                        // the most severe diagnostic — confirm against the
                        // severity type's ordering.
                        let position = buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start);
                        (has_collaborators, position)
                    });

                // Skip buffers that have any such selection.
                if has_collaborators {
                    continue;
                }

                if let Some(position) = diagnostic_position {
                    jump_location = Some((candidate_buffer, position));
                    break;
                }
            }
        }

        anyhow::Ok(jump_location)
    }
2098
2099    async fn send_raw_llm_request(
2100        request: RawCompletionRequest,
2101        client: Arc<Client>,
2102        custom_url: Option<Arc<Url>>,
2103        llm_token: LlmApiToken,
2104        app_version: Version,
2105    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2106        let url = if let Some(custom_url) = custom_url {
2107            custom_url.as_ref().clone()
2108        } else {
2109            client
2110                .http_client()
2111                .build_zed_llm_url("/predict_edits/raw", &[])?
2112        };
2113
2114        Self::send_api_request(
2115            |builder| {
2116                let req = builder
2117                    .uri(url.as_ref())
2118                    .body(serde_json::to_string(&request)?.into());
2119                Ok(req?)
2120            },
2121            client,
2122            llm_token,
2123            app_version,
2124            true,
2125        )
2126        .await
2127    }
2128
2129    pub(crate) async fn send_v3_request(
2130        input: ZetaPromptInput,
2131        client: Arc<Client>,
2132        llm_token: LlmApiToken,
2133        app_version: Version,
2134        trigger: PredictEditsRequestTrigger,
2135    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2136        let url = client
2137            .http_client()
2138            .build_zed_llm_url("/predict_edits/v3", &[])?;
2139
2140        let request = PredictEditsV3Request { input, trigger };
2141
2142        let json_bytes = serde_json::to_vec(&request)?;
2143        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2144
2145        Self::send_api_request(
2146            |builder| {
2147                let req = builder
2148                    .uri(url.as_ref())
2149                    .header("Content-Encoding", "zstd")
2150                    .body(compressed.clone().into());
2151                Ok(req?)
2152            },
2153            client,
2154            llm_token,
2155            app_version,
2156            true,
2157        )
2158        .await
2159    }
2160
2161    fn handle_api_response<T>(
2162        this: &WeakEntity<Self>,
2163        response: Result<(T, Option<EditPredictionUsage>)>,
2164        cx: &mut gpui::AsyncApp,
2165    ) -> Result<T> {
2166        match response {
2167            Ok((data, usage)) => {
2168                if let Some(usage) = usage {
2169                    this.update(cx, |this, cx| {
2170                        this.user_store.update(cx, |user_store, cx| {
2171                            user_store.update_edit_prediction_usage(usage, cx);
2172                        });
2173                    })
2174                    .ok();
2175                }
2176                Ok(data)
2177            }
2178            Err(err) => {
2179                if err.is::<ZedUpdateRequiredError>() {
2180                    cx.update(|cx| {
2181                        this.update(cx, |this, _cx| {
2182                            this.update_required = true;
2183                        })
2184                        .ok();
2185
2186                        let error_message: SharedString = err.to_string().into();
2187                        show_app_notification(
2188                            NotificationId::unique::<ZedUpdateRequiredError>(),
2189                            cx,
2190                            move |cx| {
2191                                cx.new(|cx| {
2192                                    ErrorMessagePrompt::new(error_message.clone(), cx)
2193                                        .with_link_button("Update Zed", "https://zed.dev/releases")
2194                                })
2195                            },
2196                        );
2197                    });
2198                }
2199                Err(err)
2200            }
2201        }
2202    }
2203
2204    async fn send_api_request<Res>(
2205        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2206        client: Arc<Client>,
2207        llm_token: LlmApiToken,
2208        app_version: Version,
2209        require_auth: bool,
2210    ) -> Result<(Res, Option<EditPredictionUsage>)>
2211    where
2212        Res: DeserializeOwned,
2213    {
2214        let http_client = client.http_client();
2215
2216        let mut token = if require_auth {
2217            Some(llm_token.acquire(&client).await?)
2218        } else {
2219            llm_token.acquire(&client).await.ok()
2220        };
2221        let mut did_retry = false;
2222
2223        loop {
2224            let request_builder = http_client::Request::builder().method(Method::POST);
2225
2226            let mut request_builder = request_builder
2227                .header("Content-Type", "application/json")
2228                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2229
2230            // Only add Authorization header if we have a token
2231            if let Some(ref token_value) = token {
2232                request_builder =
2233                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2234            }
2235
2236            let request = build(request_builder)?;
2237
2238            let mut response = http_client.send(request).await?;
2239
2240            if let Some(minimum_required_version) = response
2241                .headers()
2242                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2243                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2244            {
2245                anyhow::ensure!(
2246                    app_version >= minimum_required_version,
2247                    ZedUpdateRequiredError {
2248                        minimum_version: minimum_required_version
2249                    }
2250                );
2251            }
2252
2253            if response.status().is_success() {
2254                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2255
2256                let mut body = Vec::new();
2257                response.body_mut().read_to_end(&mut body).await?;
2258                return Ok((serde_json::from_slice(&body)?, usage));
2259            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2260                did_retry = true;
2261                token = Some(llm_token.refresh(&client).await?);
2262            } else {
2263                let mut body = String::new();
2264                response.body_mut().read_to_string(&mut body).await?;
2265                anyhow::bail!(
2266                    "Request failed with status: {:?}\nBody: {}",
2267                    response.status(),
2268                    body
2269                );
2270            }
2271        }
2272    }
2273
2274    pub fn refresh_context(
2275        &mut self,
2276        project: &Entity<Project>,
2277        buffer: &Entity<language::Buffer>,
2278        cursor_position: language::Anchor,
2279        cx: &mut Context<Self>,
2280    ) {
2281        self.get_or_init_project(project, cx)
2282            .context
2283            .update(cx, |store, cx| {
2284                store.refresh(buffer.clone(), cursor_position, cx);
2285            });
2286    }
2287
2288    #[cfg(feature = "cli-support")]
2289    pub fn set_context_for_buffer(
2290        &mut self,
2291        project: &Entity<Project>,
2292        related_files: Vec<RelatedFile>,
2293        cx: &mut Context<Self>,
2294    ) {
2295        self.get_or_init_project(project, cx)
2296            .context
2297            .update(cx, |store, cx| {
2298                store.set_related_files(related_files, cx);
2299            });
2300    }
2301
2302    #[cfg(feature = "cli-support")]
2303    pub fn set_recent_paths_for_project(
2304        &mut self,
2305        project: &Entity<Project>,
2306        paths: impl IntoIterator<Item = project::ProjectPath>,
2307        cx: &mut Context<Self>,
2308    ) {
2309        let project_state = self.get_or_init_project(project, cx);
2310        project_state.recent_paths = paths.into_iter().collect();
2311    }
2312
2313    fn is_file_open_source(
2314        &self,
2315        project: &Entity<Project>,
2316        file: &Arc<dyn File>,
2317        cx: &App,
2318    ) -> bool {
2319        if !file.is_local() || file.is_private() {
2320            return false;
2321        }
2322        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2323            return false;
2324        };
2325        project_state
2326            .license_detection_watchers
2327            .get(&file.worktree_id(cx))
2328            .as_ref()
2329            .is_some_and(|watcher| watcher.is_project_open_source())
2330    }
2331
2332    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2333        self.data_collection_choice.is_enabled(cx)
2334    }
2335
2336    fn load_data_collection_choice() -> DataCollectionChoice {
2337        let choice = KEY_VALUE_STORE
2338            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2339            .log_err()
2340            .flatten();
2341
2342        match choice.as_deref() {
2343            Some("true") => DataCollectionChoice::Enabled,
2344            Some("false") => DataCollectionChoice::Disabled,
2345            Some(_) => {
2346                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2347                DataCollectionChoice::NotAnswered
2348            }
2349            None => DataCollectionChoice::NotAnswered,
2350        }
2351    }
2352
2353    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2354        self.data_collection_choice = self.data_collection_choice.toggle();
2355        let new_choice = self.data_collection_choice;
2356        let is_enabled = new_choice.is_enabled(cx);
2357        db::write_and_log(cx, move || {
2358            KEY_VALUE_STORE.write_kvp(
2359                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2360                is_enabled.to_string(),
2361            )
2362        });
2363    }
2364
2365    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2366        self.shown_predictions.iter()
2367    }
2368
2369    pub fn shown_completions_len(&self) -> usize {
2370        self.shown_predictions.len()
2371    }
2372
2373    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2374        self.rated_predictions.contains(id)
2375    }
2376
2377    pub fn rate_prediction(
2378        &mut self,
2379        prediction: &EditPrediction,
2380        rating: EditPredictionRating,
2381        feedback: String,
2382        cx: &mut Context<Self>,
2383    ) {
2384        let organization = self.user_store.read(cx).current_organization();
2385
2386        self.rated_predictions.insert(prediction.id.clone());
2387
2388        cx.background_spawn({
2389            let client = self.client.clone();
2390            let prediction_id = prediction.id.to_string();
2391            let inputs = serde_json::to_value(&prediction.inputs);
2392            let output = prediction
2393                .edit_preview
2394                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2395            async move {
2396                client
2397                    .cloud_client()
2398                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2399                        organization_id: organization.map(|organization| organization.id.clone()),
2400                        request_id: prediction_id,
2401                        rating: match rating {
2402                            EditPredictionRating::Positive => "positive".to_string(),
2403                            EditPredictionRating::Negative => "negative".to_string(),
2404                        },
2405                        inputs: inputs?,
2406                        output,
2407                        feedback,
2408                    })
2409                    .await?;
2410
2411                anyhow::Ok(())
2412            }
2413        })
2414        .detach_and_log_err(cx);
2415
2416        cx.notify();
2417    }
2418}
2419
2420fn merge_trailing_events_if_needed(
2421    events: &mut VecDeque<StoredEvent>,
2422    end_snapshot: &TextBufferSnapshot,
2423    latest_snapshot: &TextBufferSnapshot,
2424    latest_edit_range: &Range<Anchor>,
2425) {
2426    if let Some(last_event) = events.back() {
2427        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
2428            return;
2429        }
2430    }
2431
2432    let mut next_old_event = None;
2433    let mut mergeable_count = 0;
2434    for old_event in events.iter().rev() {
2435        if let Some(next_old_event) = &next_old_event
2436            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
2437        {
2438            break;
2439        }
2440        mergeable_count += 1;
2441        next_old_event = Some(old_event);
2442    }
2443
2444    if mergeable_count <= 1 {
2445        return;
2446    }
2447
2448    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
2449    let oldest_event = events_to_merge.peek().unwrap();
2450    let oldest_snapshot = oldest_event.old_snapshot.clone();
2451
2452    if let Some((diff, edited_range)) =
2453        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
2454    {
2455        let merged_event = match oldest_event.event.as_ref() {
2456            zeta_prompt::Event::BufferChange {
2457                old_path,
2458                path,
2459                in_open_source_repo,
2460                ..
2461            } => StoredEvent {
2462                event: Arc::new(zeta_prompt::Event::BufferChange {
2463                    old_path: old_path.clone(),
2464                    path: path.clone(),
2465                    diff,
2466                    in_open_source_repo: *in_open_source_repo,
2467                    predicted: events_to_merge.all(|e| {
2468                        matches!(
2469                            e.event.as_ref(),
2470                            zeta_prompt::Event::BufferChange {
2471                                predicted: true,
2472                                ..
2473                            }
2474                        )
2475                    }),
2476                }),
2477                old_snapshot: oldest_snapshot.clone(),
2478                edit_range: end_snapshot.anchor_before(edited_range.start)
2479                    ..end_snapshot.anchor_before(edited_range.end),
2480            },
2481        };
2482        events.truncate(events.len() - mergeable_count);
2483        events.push_back(merged_event);
2484    }
2485}
2486
2487pub(crate) fn filter_redundant_excerpts(
2488    mut related_files: Vec<RelatedFile>,
2489    cursor_path: &Path,
2490    cursor_row_range: Range<u32>,
2491) -> Vec<RelatedFile> {
2492    for file in &mut related_files {
2493        if file.path.as_ref() == cursor_path {
2494            file.excerpts.retain(|excerpt| {
2495                excerpt.row_range.start < cursor_row_range.start
2496                    || excerpt.row_range.end > cursor_row_range.end
2497            });
2498        }
2499    }
2500    related_files.retain(|file| !file.excerpts.is_empty());
2501    related_files
2502}
2503
/// Error indicating that the running Zed version is too old to keep using
/// edit predictions.
///
/// Its display message tells the user which version to update to.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Version interpolated into the error message as the minimum to update to.
    minimum_version: Version,
}
2511
/// The user's answer to the edit-prediction data-collection prompt.
///
/// Derives `PartialEq`/`Eq` so choices can be compared directly (for example
/// in assertions) instead of only through pattern matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    /// No recognized answer has been stored; treated as not enabled.
    NotAnswered,
    /// The user opted in to data collection.
    Enabled,
    /// The user opted out of data collection.
    Disabled,
}
2518
2519impl DataCollectionChoice {
2520    pub fn is_enabled(self, cx: &App) -> bool {
2521        if cx.is_staff() {
2522            return true;
2523        }
2524        match self {
2525            Self::Enabled => true,
2526            Self::NotAnswered | Self::Disabled => false,
2527        }
2528    }
2529
2530    #[must_use]
2531    pub fn toggle(&self) -> DataCollectionChoice {
2532        match self {
2533            Self::Enabled => Self::Disabled,
2534            Self::Disabled => Self::Enabled,
2535            Self::NotAnswered => Self::Enabled,
2536        }
2537    }
2538}
2539
2540impl From<bool> for DataCollectionChoice {
2541    fn from(value: bool) -> Self {
2542        match value {
2543            true => DataCollectionChoice::Enabled,
2544            false => DataCollectionChoice::Disabled,
2545        }
2546    }
2547}
2548
2549struct ZedPredictUpsell;
2550
2551impl Dismissable for ZedPredictUpsell {
2552    const KEY: &'static str = "dismissed-edit-predict-upsell";
2553
2554    fn dismissed() -> bool {
2555        // To make this backwards compatible with older versions of Zed, we
2556        // check if the user has seen the previous Edit Prediction Onboarding
2557        // before, by checking the data collection choice which was written to
2558        // the database once the user clicked on "Accept and Enable"
2559        if KEY_VALUE_STORE
2560            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2561            .log_err()
2562            .is_some_and(|s| s.is_some())
2563        {
2564            return true;
2565        }
2566
2567        KEY_VALUE_STORE
2568            .read_kvp(Self::KEY)
2569            .log_err()
2570            .is_some_and(|s| s.is_some())
2571    }
2572}
2573
2574pub fn should_show_upsell_modal() -> bool {
2575    !ZedPredictUpsell::dismissed()
2576}
2577
2578pub fn init(cx: &mut App) {
2579    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2580        workspace.register_action(
2581            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2582                ZedPredictModal::toggle(
2583                    workspace,
2584                    workspace.user_store().clone(),
2585                    workspace.client().clone(),
2586                    window,
2587                    cx,
2588                )
2589            },
2590        );
2591
2592        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2593            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2594                settings
2595                    .project
2596                    .all_languages
2597                    .edit_predictions
2598                    .get_or_insert_default()
2599                    .provider = Some(EditPredictionProvider::None)
2600            });
2601        });
2602        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2603            EditPredictionStore::try_global(cx).and_then(|store| {
2604                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2605            })
2606        }
2607
2608        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2609            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2610                copilot_ui::initiate_sign_in(copilot, window, cx);
2611            }
2612        });
2613        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2614            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2615                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2616            }
2617        });
2618        workspace.register_action(|workspace, _: &SignOut, window, cx| {
2619            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2620                copilot_ui::initiate_sign_out(copilot, window, cx);
2621            }
2622        });
2623    })
2624    .detach();
2625}