edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{
   5    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   6};
   7use cloud_llm_client::{
   8    EditPredictionRejectReason, EditPredictionRejection,
   9    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  10    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  11};
  12use collections::{HashMap, HashSet};
  13use copilot::{Copilot, Reinstall, SignIn, SignOut};
  14use db::kvp::{Dismissable, KEY_VALUE_STORE};
  15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::{
  18    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  19    channel::mpsc::{self, UnboundedReceiver},
  20    select_biased,
  21};
  22use gpui::BackgroundExecutor;
  23use gpui::http_client::Url;
  24use gpui::{
  25    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  26    http_client::{self, AsyncBody, Method},
  27    prelude::*,
  28};
  29use language::language_settings::all_language_settings;
  30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  31use language::{BufferSnapshot, OffsetRangeExt};
  32use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  33use project::{Project, ProjectPath, WorktreeId};
  34use release_channel::AppVersion;
  35use semver::Version;
  36use serde::de::DeserializeOwned;
  37use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
  38use std::collections::{VecDeque, hash_map};
  39use text::Edit;
  40use workspace::Workspace;
  41use zeta_prompt::ZetaPromptInput;
  42use zeta_prompt::ZetaVersion;
  43
  44use std::ops::Range;
  45use std::path::Path;
  46use std::rc::Rc;
  47use std::str::FromStr as _;
  48use std::sync::{Arc, LazyLock};
  49use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  50use std::{env, mem};
  51use thiserror::Error;
  52use util::{RangeExt as _, ResultExt as _};
  53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  54
  55pub mod cursor_excerpt;
  56pub mod example_spec;
  57mod license_detection;
  58pub mod mercury;
  59mod onboarding_modal;
  60pub mod open_ai_response;
  61mod prediction;
  62pub mod sweep_ai;
  63
  64pub mod udiff;
  65
  66mod capture_example;
  67mod zed_edit_prediction_delegate;
  68pub mod zeta1;
  69pub mod zeta2;
  70
  71#[cfg(test)]
  72mod edit_prediction_tests;
  73
  74use crate::capture_example::should_sample_edit_prediction_example_capture;
  75use crate::license_detection::LicenseDetectionWatcher;
  76use crate::mercury::Mercury;
  77use crate::onboarding_modal::ZedPredictModal;
  78pub use crate::prediction::EditPrediction;
  79pub use crate::prediction::EditPredictionId;
  80use crate::prediction::EditPredictionResult;
  81pub use crate::sweep_ai::SweepAi;
  82pub use capture_example::capture_example;
  83pub use language_model::ApiKeyState;
  84pub use telemetry_events::EditPredictionRating;
  85pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  86
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  96
  97/// Maximum number of events to track.
  98const EVENT_COUNT_MAX: usize = 6;
  99const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 100const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
 101const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 102const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 103
/// Feature flag `sweep-ai` (presumably gating the Sweep AI prediction backend).
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}

/// Feature flag `mercury` (presumably gating the Mercury prediction backend).
pub struct MercuryFeatureFlag;

impl FeatureFlag for MercuryFeatureFlag {
    const NAME: &str = "mercury";
}
 115
 116static EDIT_PREDICTIONS_MODEL_ID: LazyLock<Option<String>> =
 117    LazyLock::new(|| env::var("ZED_ZETA_MODEL").ok());
 118
 119pub struct Zeta2FeatureFlag;
 120
 121impl FeatureFlag for Zeta2FeatureFlag {
 122    const NAME: &'static str = "zeta2";
 123
 124    fn enabled_for_staff() -> bool {
 125        true
 126    }
 127}
 128
 129pub struct EditPredictionExampleCaptureFeatureFlag;
 130
 131impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
 132    const NAME: &'static str = "edit-prediction-example-capture";
 133
 134    fn enabled_for_staff() -> bool {
 135        true
 136    }
 137}
 138
/// App-global holder for the shared [`EditPredictionStore`] entity. It is installed by
/// `EditPredictionStore::global` and read by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 143
/// Store holding per-project edit-prediction state, the selected prediction model, and the
/// client handles used to request predictions.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Keeps the token-refresh subscription alive; when it fires, `llm_token` is refreshed.
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    // Set through `set_use_context`; `new` initializes it to `false`.
    use_context: bool,
    // NOTE(review): initialized to `false` in `new`; where it becomes `true` is not visible here.
    update_required: bool,
    // Backend used for predictions; `new` starts with `Zeta2` at the default version.
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    // Sending half of the channel drained by the rejection-handling task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    // Endpoint override parsed from `ZED_PREDICT_EDITS_URL` when the store is created.
    custom_predict_edits_url: Option<Arc<Url>>,
}
 161
/// Backend used to produce edit predictions.
///
/// The derived `Default` is `Zeta1`, but `EditPredictionStore::new` explicitly starts with
/// `Zeta2` at the default `ZetaVersion`. `EditPredictionStore::usage` reports usage only
/// for `Zeta2`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2 {
        version: ZetaVersion,
    },
    Sweep,
    Mercury,
}
 172
/// Inputs gathered for a single edit-prediction request.
///
/// NOTE(review): the code that builds this struct and the models that consume it are not
/// visible in this section; the field notes below are inferred from names and types.
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    // Presumably the cursor position the prediction is requested at.
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
 186
/// Diagnostic events describing context retrieval and prediction requests.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Context retrieval began for a project.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // The excerpt-store handler in this file sends an empty string here.
    pub search_prompt: String,
}

/// Context retrieval finished for a project.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // Label/value pairs for display, e.g. cache hit ratio and LSP latencies.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// A prediction request started for `buffer` at `position`.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    // NOTE(review): presumably `None` when no prompt text is available; confirm at the sender.
    pub prompt: Option<String>,
}

/// A prediction request finished for `buffer` at `position`.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
 222
/// Maximum number of entries kept in `ProjectState::user_actions`; when it is full,
/// `ProjectState::record_user_action` evicts the oldest entry.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One recorded user action (typing, deleting, or moving the cursor) in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    // Presumably milliseconds since the UNIX epoch, per the name; confirm at the producer.
    pub timestamp_epoch_ms: u64,
}

/// Kind of user action captured in a [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 242
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-level event, e.g. a buffer change carrying its diff.
    pub event: Arc<zeta_prompt::Event>,
    /// Snapshot of the buffer from before the event.
    pub old_snapshot: TextBufferSnapshot,
}
 249
/// Edit-prediction state tracked for a single project.
struct ProjectState {
    // Stored edit events; `events()` appends the finalized `last_event` after these.
    events: VecDeque<StoredEvent>,
    // The most recent change, not yet moved into `events`.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    // Buffers registered for this project, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    // Presumably the id handed to the next `PendingPrediction`.
    next_pending_prediction_id: usize,
    // At most two in-flight prediction requests (fixed `ArrayVec` capacity).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // When set, debug events (e.g. context-retrieval progress) are sent here.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions that `cancel_pending_prediction` has cancelled.
    cancelled_predictions: HashSet<usize>,
    // Related-excerpt store used for this project's context retrieval.
    context: Entity<RelatedExcerptStore>,
    // Per-worktree license watchers; consulted when marking events as open-source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // Bounded to USER_ACTION_HISTORY_SIZE entries by `record_user_action`.
    user_actions: VecDeque<UserActionRecord>,
    // Subscription to the project's events, kept alive together with this state.
    _subscription: gpui::Subscription,
    // Copilot instance, created lazily by `start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
 267
 268impl ProjectState {
 269    fn record_user_action(&mut self, action: UserActionRecord) {
 270        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 271            self.user_actions.pop_front();
 272        }
 273        self.user_actions.push_back(action);
 274    }
 275
 276    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 277        self.events
 278            .iter()
 279            .cloned()
 280            .chain(
 281                self.last_event
 282                    .as_ref()
 283                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 284            )
 285            .collect()
 286    }
 287
 288    pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
 289        self.events
 290            .iter()
 291            .cloned()
 292            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 293                let (one, two) = event.split_by_pause();
 294                let one = one.finalize(&self.license_detection_watchers, cx);
 295                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 296                one.into_iter().chain(two)
 297            }))
 298            .collect()
 299    }
 300
 301    fn cancel_pending_prediction(
 302        &mut self,
 303        pending_prediction: PendingPrediction,
 304        cx: &mut Context<EditPredictionStore>,
 305    ) {
 306        self.cancelled_predictions.insert(pending_prediction.id);
 307
 308        cx.spawn(async move |this, cx| {
 309            let Some(prediction_id) = pending_prediction.task.await else {
 310                return;
 311            };
 312
 313            this.update(cx, |this, _cx| {
 314                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 315            })
 316            .ok();
 317        })
 318        .detach()
 319    }
 320
 321    fn active_buffer(
 322        &self,
 323        project: &Entity<Project>,
 324        cx: &App,
 325    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 326        let project = project.read(cx);
 327        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 328        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 329        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 330        Some((active_buffer, registered_buffer.last_position))
 331    }
 332}
 333
/// The prediction currently held for a project (see `ProjectState::current_prediction`),
/// together with what requested it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // NOTE(review): presumably set once the prediction is displayed; the writer is not visible here.
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
 341
 342impl CurrentEditPrediction {
 343    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 344        let Some(new_edits) = self
 345            .prediction
 346            .interpolate(&self.prediction.buffer.read(cx))
 347        else {
 348            return false;
 349        };
 350
 351        if self.prediction.buffer != old_prediction.prediction.buffer {
 352            return true;
 353        }
 354
 355        let Some(old_edits) = old_prediction
 356            .prediction
 357            .interpolate(&old_prediction.prediction.buffer.read(cx))
 358        else {
 359            return true;
 360        };
 361
 362        let requested_by_buffer_id = self.requested_by.buffer_id();
 363
 364        // This reduces the occurrence of UI thrash from replacing edits
 365        //
 366        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 367        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 368            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 369            && old_edits.len() == 1
 370            && new_edits.len() == 1
 371        {
 372            let (old_range, old_text) = &old_edits[0];
 373            let (new_range, new_text) = &new_edits[0];
 374            new_range == old_range && new_text.starts_with(old_text.as_ref())
 375        } else {
 376            true
 377        }
 378    }
 379}
 380
 381#[derive(Debug, Clone)]
 382enum PredictionRequestedBy {
 383    DiagnosticsUpdate,
 384    Buffer(EntityId),
 385}
 386
 387impl PredictionRequestedBy {
 388    pub fn buffer_id(&self) -> Option<EntityId> {
 389        match self {
 390            PredictionRequestedBy::DiagnosticsUpdate => None,
 391            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 392        }
 393    }
 394}
 395
/// An in-flight prediction request tracked in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    // Local id; recorded in `cancelled_predictions` when the request is cancelled.
    id: usize,
    // Resolves to the produced prediction's id; `None` presumably means nothing was produced.
    task: Task<Option<EditPredictionId>>,
}
 401
 402/// A prediction from the perspective of a buffer.
 403#[derive(Debug)]
 404enum BufferEditPrediction<'a> {
 405    Local { prediction: &'a EditPrediction },
 406    Jump { prediction: &'a EditPrediction },
 407}
 408
 409#[cfg(test)]
 410impl std::ops::Deref for BufferEditPrediction<'_> {
 411    type Target = EditPrediction;
 412
 413    fn deref(&self) -> &Self::Target {
 414        match self {
 415            BufferEditPrediction::Local { prediction } => prediction,
 416            BufferEditPrediction::Jump { prediction } => prediction,
 417        }
 418    }
 419}
 420
/// State kept for each buffer registered with a project (see
/// `ProjectState::registered_buffers`).
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    // Presumably the last known cursor position; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    // Held only to keep the two subscriptions alive; they are created elsewhere.
    _subscriptions: [gpui::Subscription; 2],
}
 427
/// The most recent buffer change, held separately from the stored `events` queue until
/// `finalize` turns it into a [`StoredEvent`].
#[derive(Clone)]
struct LastEvent {
    // Buffer contents before the change and after it.
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    // File handles before and after; used for paths and open-source detection.
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    // NOTE(review): presumably the range touched by the change; cleared by `split_by_pause`.
    edit_range: Option<Range<Anchor>>,
    // Boundary snapshot that `split_by_pause` uses to split this event in two.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 438
 439impl LastEvent {
 440    pub fn finalize(
 441        &self,
 442        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 443        cx: &App,
 444    ) -> Option<StoredEvent> {
 445        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 446        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 447
 448        let in_open_source_repo =
 449            [self.new_file.as_ref(), self.old_file.as_ref()]
 450                .iter()
 451                .all(|file| {
 452                    file.is_some_and(|file| {
 453                        license_detection_watchers
 454                            .get(&file.worktree_id(cx))
 455                            .is_some_and(|watcher| watcher.is_project_open_source())
 456                    })
 457                });
 458
 459        let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 460
 461        if path == old_path && diff.is_empty() {
 462            None
 463        } else {
 464            Some(StoredEvent {
 465                event: Arc::new(zeta_prompt::Event::BufferChange {
 466                    old_path,
 467                    path,
 468                    diff,
 469                    in_open_source_repo,
 470                    // TODO: Actually detect if this edit was predicted or not
 471                    predicted: false,
 472                }),
 473                old_snapshot: self.old_snapshot.clone(),
 474            })
 475        }
 476    }
 477
 478    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 479        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 480            return (self.clone(), None);
 481        };
 482
 483        let before = LastEvent {
 484            old_snapshot: self.old_snapshot.clone(),
 485            new_snapshot: boundary_snapshot.clone(),
 486            old_file: self.old_file.clone(),
 487            new_file: self.new_file.clone(),
 488            edit_range: None,
 489            snapshot_after_last_editing_pause: None,
 490            last_edit_time: self.last_edit_time,
 491        };
 492
 493        let after = LastEvent {
 494            old_snapshot: boundary_snapshot.clone(),
 495            new_snapshot: self.new_snapshot.clone(),
 496            old_file: self.old_file.clone(),
 497            new_file: self.new_file.clone(),
 498            edit_range: None,
 499            snapshot_after_last_editing_pause: None,
 500            last_edit_time: self.last_edit_time,
 501        };
 502
 503        (before, Some(after))
 504    }
 505}
 506
 507pub(crate) fn compute_diff_between_snapshots(
 508    old_snapshot: &TextBufferSnapshot,
 509    new_snapshot: &TextBufferSnapshot,
 510) -> Option<String> {
 511    let edits: Vec<Edit<usize>> = new_snapshot
 512        .edits_since::<usize>(&old_snapshot.version)
 513        .collect();
 514
 515    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 516
 517    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 518    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 519    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 520    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 521
 522    const CONTEXT_LINES: u32 = 3;
 523
 524    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 525    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 526    let old_context_end_row =
 527        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 528    let new_context_end_row =
 529        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 530
 531    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 532    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 533    let old_end_line_offset = old_snapshot
 534        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 535    let new_end_line_offset = new_snapshot
 536        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 537    let old_edit_range = old_start_line_offset..old_end_line_offset;
 538    let new_edit_range = new_start_line_offset..new_end_line_offset;
 539
 540    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 541    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 542
 543    let diff = language::unified_diff_with_offsets(
 544        &old_region_text,
 545        &new_region_text,
 546        old_context_start_row,
 547        new_context_start_row,
 548    );
 549
 550    Some(diff)
 551}
 552
 553fn buffer_path_with_id_fallback(
 554    file: Option<&Arc<dyn File>>,
 555    snapshot: &TextBufferSnapshot,
 556    cx: &App,
 557) -> Arc<Path> {
 558    if let Some(file) = file {
 559        file.full_path(cx).into()
 560    } else {
 561        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 562    }
 563}
 564
 565impl EditPredictionStore {
 566    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 567        cx.try_global::<EditPredictionStoreGlobal>()
 568            .map(|global| global.0.clone())
 569    }
 570
 571    pub fn global(
 572        client: &Arc<Client>,
 573        user_store: &Entity<UserStore>,
 574        cx: &mut App,
 575    ) -> Entity<Self> {
 576        cx.try_global::<EditPredictionStoreGlobal>()
 577            .map(|global| global.0.clone())
 578            .unwrap_or_else(|| {
 579                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 580                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 581                ep_store
 582            })
 583    }
 584
 585    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 586        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 587        let data_collection_choice = Self::load_data_collection_choice();
 588
 589        let llm_token = LlmApiToken::default();
 590
 591        let (reject_tx, reject_rx) = mpsc::unbounded();
 592        cx.background_spawn({
 593            let client = client.clone();
 594            let llm_token = llm_token.clone();
 595            let app_version = AppVersion::global(cx);
 596            let background_executor = cx.background_executor().clone();
 597            async move {
 598                Self::handle_rejected_predictions(
 599                    reject_rx,
 600                    client,
 601                    llm_token,
 602                    app_version,
 603                    background_executor,
 604                )
 605                .await
 606            }
 607        })
 608        .detach();
 609
 610        let mut this = Self {
 611            projects: HashMap::default(),
 612            client,
 613            user_store,
 614            use_context: false,
 615            llm_token,
 616            _llm_token_subscription: cx.subscribe(
 617                &refresh_llm_token_listener,
 618                |this, _listener, _event, cx| {
 619                    let client = this.client.clone();
 620                    let llm_token = this.llm_token.clone();
 621                    cx.spawn(async move |_this, _cx| {
 622                        llm_token.refresh(&client).await?;
 623                        anyhow::Ok(())
 624                    })
 625                    .detach_and_log_err(cx);
 626                },
 627            ),
 628            update_required: false,
 629            edit_prediction_model: EditPredictionModel::Zeta2 {
 630                version: Default::default(),
 631            },
 632            sweep_ai: SweepAi::new(cx),
 633            mercury: Mercury::new(cx),
 634
 635            data_collection_choice,
 636            reject_predictions_tx: reject_tx,
 637            rated_predictions: Default::default(),
 638            shown_predictions: Default::default(),
 639            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
 640                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
 641                Err(_) => None,
 642            },
 643        };
 644
 645        this.configure_context_retrieval(cx);
 646        let weak_this = cx.weak_entity();
 647        cx.on_flags_ready(move |_, cx| {
 648            weak_this
 649                .update(cx, |this, cx| this.configure_context_retrieval(cx))
 650                .ok();
 651        })
 652        .detach();
 653        cx.observe_global::<SettingsStore>(|this, cx| {
 654            this.configure_context_retrieval(cx);
 655        })
 656        .detach();
 657
 658        this
 659    }
 660
 661    #[cfg(test)]
 662    pub fn set_custom_predict_edits_url(&mut self, url: Url) {
 663        self.custom_predict_edits_url = Some(url.into());
 664    }
 665
 666    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 667        self.edit_prediction_model = model;
 668    }
 669
 670    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 671        self.sweep_ai.api_token.read(cx).has_key()
 672    }
 673
 674    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 675        self.mercury.api_token.read(cx).has_key()
 676    }
 677
 678    pub fn set_use_context(&mut self, use_context: bool) {
 679        self.use_context = use_context;
 680    }
 681
 682    pub fn clear_history(&mut self) {
 683        for project_state in self.projects.values_mut() {
 684            project_state.events.clear();
 685            project_state.last_event.take();
 686        }
 687    }
 688
 689    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 690        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 691            project_state.events.clear();
 692            project_state.last_event.take();
 693        }
 694    }
 695
 696    pub fn edit_history_for_project(
 697        &self,
 698        project: &Entity<Project>,
 699        cx: &App,
 700    ) -> Vec<StoredEvent> {
 701        self.projects
 702            .get(&project.entity_id())
 703            .map(|project_state| project_state.events(cx))
 704            .unwrap_or_default()
 705    }
 706
 707    pub fn edit_history_for_project_with_pause_split_last_event(
 708        &self,
 709        project: &Entity<Project>,
 710        cx: &App,
 711    ) -> Vec<StoredEvent> {
 712        self.projects
 713            .get(&project.entity_id())
 714            .map(|project_state| project_state.events_split_by_pause(cx))
 715            .unwrap_or_default()
 716    }
 717
 718    pub fn context_for_project<'a>(
 719        &'a self,
 720        project: &Entity<Project>,
 721        cx: &'a mut App,
 722    ) -> Vec<RelatedFile> {
 723        self.projects
 724            .get(&project.entity_id())
 725            .map(|project| {
 726                project
 727                    .context
 728                    .update(cx, |context, cx| context.related_files(cx))
 729            })
 730            .unwrap_or_default()
 731    }
 732
 733    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 734        self.projects
 735            .get(&project.entity_id())
 736            .and_then(|project| project.copilot.clone())
 737    }
 738
 739    pub fn start_copilot_for_project(
 740        &mut self,
 741        project: &Entity<Project>,
 742        cx: &mut Context<Self>,
 743    ) -> Option<Entity<Copilot>> {
 744        let state = self.get_or_init_project(project, cx);
 745
 746        if state.copilot.is_some() {
 747            return state.copilot.clone();
 748        }
 749        let _project = project.clone();
 750        let project = project.read(cx);
 751
 752        let node = project.node_runtime().cloned();
 753        if let Some(node) = node {
 754            let next_id = project.languages().next_language_server_id();
 755            let fs = project.fs().clone();
 756
 757            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 758            state.copilot = Some(copilot.clone());
 759            Some(copilot)
 760        } else {
 761            None
 762        }
 763    }
 764
 765    pub fn context_for_project_with_buffers<'a>(
 766        &'a self,
 767        project: &Entity<Project>,
 768        cx: &'a mut App,
 769    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 770        self.projects
 771            .get(&project.entity_id())
 772            .map(|project| {
 773                project
 774                    .context
 775                    .update(cx, |context, cx| context.related_files_with_buffers(cx))
 776            })
 777            .unwrap_or_default()
 778    }
 779
 780    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 781        if matches!(
 782            self.edit_prediction_model,
 783            EditPredictionModel::Zeta2 { .. }
 784        ) {
 785            self.user_store.read(cx).edit_prediction_usage()
 786        } else {
 787            None
 788        }
 789    }
 790
 791    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 792        self.get_or_init_project(project, cx);
 793    }
 794
 795    pub fn register_buffer(
 796        &mut self,
 797        buffer: &Entity<Buffer>,
 798        project: &Entity<Project>,
 799        cx: &mut Context<Self>,
 800    ) {
 801        let project_state = self.get_or_init_project(project, cx);
 802        Self::register_buffer_impl(project_state, buffer, project, cx);
 803    }
 804
 805    fn get_or_init_project(
 806        &mut self,
 807        project: &Entity<Project>,
 808        cx: &mut Context<Self>,
 809    ) -> &mut ProjectState {
 810        let entity_id = project.entity_id();
 811        self.projects
 812            .entry(entity_id)
 813            .or_insert_with(|| ProjectState {
 814                context: {
 815                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 816                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 817                        this.handle_excerpt_store_event(entity_id, event);
 818                    })
 819                    .detach();
 820                    related_excerpt_store
 821                },
 822                events: VecDeque::new(),
 823                last_event: None,
 824                recent_paths: VecDeque::new(),
 825                debug_tx: None,
 826                registered_buffers: HashMap::default(),
 827                current_prediction: None,
 828                cancelled_predictions: HashSet::default(),
 829                pending_predictions: ArrayVec::new(),
 830                next_pending_prediction_id: 0,
 831                last_prediction_refresh: None,
 832                license_detection_watchers: HashMap::default(),
 833                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
 834                _subscription: cx.subscribe(&project, Self::handle_project_event),
 835                copilot: None,
 836            })
 837    }
 838
 839    pub fn remove_project(&mut self, project: &Entity<Project>) {
 840        self.projects.remove(&project.entity_id());
 841    }
 842
 843    fn handle_excerpt_store_event(
 844        &mut self,
 845        project_entity_id: EntityId,
 846        event: &RelatedExcerptStoreEvent,
 847    ) {
 848        if let Some(project_state) = self.projects.get(&project_entity_id) {
 849            if let Some(debug_tx) = project_state.debug_tx.clone() {
 850                match event {
 851                    RelatedExcerptStoreEvent::StartedRefresh => {
 852                        debug_tx
 853                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 854                                ContextRetrievalStartedDebugEvent {
 855                                    project_entity_id: project_entity_id,
 856                                    timestamp: Instant::now(),
 857                                    search_prompt: String::new(),
 858                                },
 859                            ))
 860                            .ok();
 861                    }
 862                    RelatedExcerptStoreEvent::FinishedRefresh {
 863                        cache_hit_count,
 864                        cache_miss_count,
 865                        mean_definition_latency,
 866                        max_definition_latency,
 867                    } => {
 868                        debug_tx
 869                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 870                                ContextRetrievalFinishedDebugEvent {
 871                                    project_entity_id: project_entity_id,
 872                                    timestamp: Instant::now(),
 873                                    metadata: vec![
 874                                        (
 875                                            "Cache Hits",
 876                                            format!(
 877                                                "{}/{}",
 878                                                cache_hit_count,
 879                                                cache_hit_count + cache_miss_count
 880                                            )
 881                                            .into(),
 882                                        ),
 883                                        (
 884                                            "Max LSP Time",
 885                                            format!("{} ms", max_definition_latency.as_millis())
 886                                                .into(),
 887                                        ),
 888                                        (
 889                                            "Mean LSP Time",
 890                                            format!("{} ms", mean_definition_latency.as_millis())
 891                                                .into(),
 892                                        ),
 893                                    ],
 894                                },
 895                            ))
 896                            .ok();
 897                    }
 898                }
 899            }
 900        }
 901    }
 902
 903    pub fn debug_info(
 904        &mut self,
 905        project: &Entity<Project>,
 906        cx: &mut Context<Self>,
 907    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 908        let project_state = self.get_or_init_project(project, cx);
 909        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 910        project_state.debug_tx = Some(debug_watch_tx);
 911        debug_watch_rx
 912    }
 913
 914    fn handle_project_event(
 915        &mut self,
 916        project: Entity<Project>,
 917        event: &project::Event,
 918        cx: &mut Context<Self>,
 919    ) {
 920        // TODO [zeta2] init with recent paths
 921        match event {
 922            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 923                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 924                    return;
 925                };
 926                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 927                if let Some(path) = path {
 928                    if let Some(ix) = project_state
 929                        .recent_paths
 930                        .iter()
 931                        .position(|probe| probe == &path)
 932                    {
 933                        project_state.recent_paths.remove(ix);
 934                    }
 935                    project_state.recent_paths.push_front(path);
 936                }
 937            }
 938            project::Event::DiagnosticsUpdated { .. } => {
 939                if cx.has_flag::<Zeta2FeatureFlag>() {
 940                    self.refresh_prediction_from_diagnostics(project, cx);
 941                }
 942            }
 943            _ => (),
 944        }
 945    }
 946
    /// Returns the [`RegisteredBuffer`] tracked for `buffer` in `project_state`,
    /// registering the buffer on first use.
    ///
    /// Side effects:
    /// * If the buffer has a file whose worktree is found in `project`, ensures a
    ///   [`LicenseDetectionWatcher`] exists for that worktree. The watcher entry is
    ///   removed again when the worktree entity is released.
    /// * A newly registered buffer gets two subscriptions: one that calls
    ///   `report_changes_for_buffer` on every `BufferEvent::Edited`, and one that
    ///   removes the registration when the buffer entity is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                // One watcher per worktree; only created the first time it is seen.
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher once the worktree itself is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    // Baseline that `report_changes_for_buffer` diffs later edits against.
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Report each edit. The project is captured as a weak handle
                        // and upgraded per event; edits are ignored once it is gone.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1012
    /// Records the edits made to `buffer` since this store last observed it.
    ///
    /// Registers the project/buffer on demand. If the buffer version is unchanged,
    /// returns early. Otherwise it does the following:
    /// 1. Swaps the stored snapshot and file for the current ones.
    /// 2. Summarizes the edits into a [`UserActionRecord`] located at the end of the
    ///    last edit. The record distinguishes insert vs. delete, and "char" vs.
    ///    "selection" by comparing total range lengths with the number of edits.
    /// 3. Folds the change into `last_event` when it directly continues the same
    ///    buffer's previous snapshot and lies within `CHANGE_GROUPING_LINE_SPAN` rows
    ///    of the previous edit range. Otherwise it finalizes `last_event` into
    ///    `events` and starts a new `LastEvent`.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the last report.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        // Spans from the first reported edit's start to the last reported edit's end.
        let mut edit_range: Option<Range<Anchor>> = None;
        // End offset of the last edit, in `new_snapshot` coordinates.
        let mut last_offset: Option<usize> = None;

        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        if num_edits > 0 {
            // When the total inserted/deleted length equals the number of edits
            // (one unit per edit), the action counts as a single-char action;
            // otherwise it counts as a selection-sized one.
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                // Falls back to 0 if the system clock is before the Unix epoch.
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // The previous event ended exactly at the snapshot this change starts from.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // ...and the two edit ranges overlap or are at most
            // CHANGE_GROUPING_LINE_SPAN rows apart.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && edit_range
                    .as_ref()
                    .zip(last_event.edit_range.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        if a.start > b.end {
                            a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else if b.start > a.end {
                            b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else {
                            true
                        }
                    });

            if should_coalesce {
                // If at least LAST_CHANGE_GROUPING_TIME passed since the previous edit,
                // remember the snapshot from before this edit as the post-pause state.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = edit_range;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescable: finalize the previous event (finalize may yield nothing)
        // and start a new one for this change.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Evict the oldest event so `events` stays below EVENT_COUNT_MAX
                // (presumably leaving room for the in-progress `last_event`).
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1136
1137    fn prediction_at(
1138        &mut self,
1139        buffer: &Entity<Buffer>,
1140        position: Option<language::Anchor>,
1141        project: &Entity<Project>,
1142        cx: &App,
1143    ) -> Option<BufferEditPrediction<'_>> {
1144        let project_state = self.projects.get_mut(&project.entity_id())?;
1145        if let Some(position) = position
1146            && let Some(buffer) = project_state
1147                .registered_buffers
1148                .get_mut(&buffer.entity_id())
1149        {
1150            buffer.last_position = Some(position);
1151        }
1152
1153        let CurrentEditPrediction {
1154            requested_by,
1155            prediction,
1156            ..
1157        } = project_state.current_prediction.as_ref()?;
1158
1159        if prediction.targets_buffer(buffer.read(cx)) {
1160            Some(BufferEditPrediction::Local { prediction })
1161        } else {
1162            let show_jump = match requested_by {
1163                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1164                    requested_by_buffer_id == &buffer.entity_id()
1165                }
1166                PredictionRequestedBy::DiagnosticsUpdate => true,
1167            };
1168
1169            if show_jump {
1170                Some(BufferEditPrediction::Jump { prediction })
1171            } else {
1172                None
1173            }
1174        }
1175    }
1176
1177    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1178        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1179            return;
1180        };
1181
1182        let Some(current_prediction) = project_state.current_prediction.take() else {
1183            return;
1184        };
1185
1186        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1187            project_state.cancel_pending_prediction(pending_prediction, cx);
1188        }
1189
1190        match self.edit_prediction_model {
1191            EditPredictionModel::Sweep => {
1192                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1193            }
1194            EditPredictionModel::Mercury => {}
1195            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1196                zeta2::edit_prediction_accepted(self, current_prediction, cx)
1197            }
1198        }
1199    }
1200
    /// Background loop that batches rejected predictions and reports them to the
    /// `/predict_edits/reject` endpoint. Runs until `rx` is closed.
    ///
    /// After a rejection arrives, the loop keeps collecting while the batch holds
    /// fewer than half of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`. A rejection
    /// that is already queued, or that arrives within `REJECT_REQUEST_DEBOUNCE`, is
    /// added before sending. Each request carries at most
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the newest entries. Those
    /// entries are removed from the batch only if the request succeeds, so failed
    /// ones are retried on a later flush.
    ///
    /// NOTE(review): on persistent request failures `batched` is never drained and
    /// grows without bound — consider capping it.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                // Debounce: if another rejection is (or becomes) available before the
                // timer fires, go back and add it to the batch. The peek yields
                // `None` when the channel closed, in which case we flush right away.
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Send the newest `flush_count` rejections.
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
                true,
            )
            .await;

            // Only forget what was actually delivered; failures are logged and kept.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1259
1260    fn reject_current_prediction(
1261        &mut self,
1262        reason: EditPredictionRejectReason,
1263        project: &Entity<Project>,
1264    ) {
1265        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1266            project_state.pending_predictions.clear();
1267            if let Some(prediction) = project_state.current_prediction.take() {
1268                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1269            }
1270        };
1271    }
1272
1273    fn did_show_current_prediction(
1274        &mut self,
1275        project: &Entity<Project>,
1276        display_type: edit_prediction_types::SuggestionDisplayType,
1277        cx: &mut Context<Self>,
1278    ) {
1279        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1280            return;
1281        };
1282
1283        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1284            return;
1285        };
1286
1287        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1288        let previous_shown_with = current_prediction.shown_with;
1289
1290        if previous_shown_with.is_none() || !is_jump {
1291            current_prediction.shown_with = Some(display_type);
1292        }
1293
1294        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1295
1296        if is_first_non_jump_show {
1297            current_prediction.was_shown = true;
1298        }
1299
1300        let display_type_changed = previous_shown_with != Some(display_type);
1301
1302        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1303            sweep_ai::edit_prediction_shown(
1304                &self.sweep_ai,
1305                self.client.clone(),
1306                &current_prediction.prediction,
1307                display_type,
1308                cx,
1309            );
1310        }
1311
1312        if is_first_non_jump_show {
1313            self.shown_predictions
1314                .push_front(current_prediction.prediction.clone());
1315            if self.shown_predictions.len() > 50 {
1316                let completion = self.shown_predictions.pop_back().unwrap();
1317                self.rated_predictions.remove(&completion.id);
1318            }
1319        }
1320    }
1321
1322    fn reject_prediction(
1323        &mut self,
1324        prediction_id: EditPredictionId,
1325        reason: EditPredictionRejectReason,
1326        was_shown: bool,
1327    ) {
1328        match self.edit_prediction_model {
1329            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1330                if self.custom_predict_edits_url.is_some() {
1331                    return;
1332                }
1333            }
1334            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1335        }
1336
1337        self.reject_predictions_tx
1338            .unbounded_send(EditPredictionRejection {
1339                request_id: prediction_id.to_string(),
1340                reason,
1341                was_shown,
1342            })
1343            .log_err();
1344    }
1345
1346    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1347        self.projects
1348            .get(&project.entity_id())
1349            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1350    }
1351
1352    pub fn refresh_prediction_from_buffer(
1353        &mut self,
1354        project: Entity<Project>,
1355        buffer: Entity<Buffer>,
1356        position: language::Anchor,
1357        cx: &mut Context<Self>,
1358    ) {
1359        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1360            let Some(request_task) = this
1361                .update(cx, |this, cx| {
1362                    this.request_prediction(
1363                        &project,
1364                        &buffer,
1365                        position,
1366                        PredictEditsRequestTrigger::Other,
1367                        cx,
1368                    )
1369                })
1370                .log_err()
1371            else {
1372                return Task::ready(anyhow::Ok(None));
1373            };
1374
1375            cx.spawn(async move |_cx| {
1376                request_task.await.map(|prediction_result| {
1377                    prediction_result.map(|prediction_result| {
1378                        (
1379                            prediction_result,
1380                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1381                        )
1382                    })
1383                })
1384            })
1385        })
1386    }
1387
    /// Queues a prediction request triggered by a diagnostics update in `project`.
    ///
    /// Skipped when the project is not tracked or already has a current prediction.
    /// Otherwise, starting from the active buffer and cursor (only if
    /// `predictions_enabled_at` allows it), it locates the next diagnostic via
    /// `next_diagnostic_location` and requests a prediction there with
    /// [`PredictEditsRequestTrigger::Diagnostics`]. If a current prediction appeared
    /// in the meantime, the result is turned into a `CurrentPreferred` rejection.
    /// Throttling is keyed on the project's entity id.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Snapshot the active buffer and cursor; bail out if predictions are
            // disabled there or the store/project state is unavailable.
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // A buffer-triggered prediction may have landed while this request was
                // running; if so, keep it and report this one as `CurrentPreferred`.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1474
1475    fn predictions_enabled_at(
1476        snapshot: &BufferSnapshot,
1477        position: Option<language::Anchor>,
1478        cx: &App,
1479    ) -> bool {
1480        let file = snapshot.file();
1481        let all_settings = all_language_settings(file, cx);
1482        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1483            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1484        {
1485            return false;
1486        }
1487
1488        if let Some(last_position) = position {
1489            let settings = snapshot.settings_at(last_position, cx);
1490
1491            if !settings.edit_predictions_disabled_in.is_empty()
1492                && let Some(scope) = snapshot.language_scope_at(last_position)
1493                && let Some(scope_name) = scope.override_name()
1494                && settings
1495                    .edit_predictions_disabled_in
1496                    .iter()
1497                    .any(|s| s == scope_name)
1498            {
1499                return false;
1500            }
1501        }
1502
1503        true
1504    }
1505
1506    #[cfg(not(test))]
1507    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1508    #[cfg(test)]
1509    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1510
    /// Schedules `do_refresh` as a pending prediction request for `project`.
    ///
    /// * If the previous refresh used the same `throttle_entity`, the request first
    ///   waits until `THROTTLE_TIMEOUT` has elapsed since that refresh started.
    /// * A request found in `cancelled_predictions` is dropped, whether the
    ///   cancellation happened before it runs or before it completes. In that case
    ///   it never becomes the current prediction.
    /// * On completion, a successful result either becomes the current prediction
    ///   or is rejected as `CurrentPreferred`; `should_replace_prediction` decides.
    ///   An `Err` result is rejected with its reason. The finished request is then
    ///   removed from `pending_predictions`, and every request enqueued before it is
    ///   cancelled.
    /// * At most two requests are kept pending. Enqueuing while two are pending
    ///   cancels the most recently enqueued one and takes its place.
    ///
    /// The spawned task yields the id of the prediction `do_refresh` produced, if any.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        let last_request = project_state.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Throttle consecutive refreshes for the same entity.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            // Also treated as cancelled when the store is gone (`update` failed).
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                                shown_with: None,
                            };

                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    // NOTE(review): `reject_current_prediction` also clears
                                    // `pending_predictions`, so the pending-list cleanup below
                                    // finds nothing in this case and newer pending requests
                                    // are dropped — confirm this is intended.
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false);
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-borrow: the branches above may have called `&mut self` methods.
                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        // Everything before index `ix` was enqueued earlier than this one.
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Keep at most two pending requests: when full, the newest pending one is
        // replaced by this request and cancelled.
        if project_state.pending_predictions.len() <= 1 {
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if project_state.pending_predictions.len() == 2 {
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1651
1652    pub fn request_prediction(
1653        &mut self,
1654        project: &Entity<Project>,
1655        active_buffer: &Entity<Buffer>,
1656        position: language::Anchor,
1657        trigger: PredictEditsRequestTrigger,
1658        cx: &mut Context<Self>,
1659    ) -> Task<Result<Option<EditPredictionResult>>> {
1660        self.request_prediction_internal(
1661            project.clone(),
1662            active_buffer.clone(),
1663            position,
1664            trigger,
1665            cx.has_flag::<Zeta2FeatureFlag>(),
1666            cx,
1667        )
1668    }
1669
    /// Builds the model input for a prediction at `position` in `active_buffer` and
    /// dispatches it to the configured `edit_prediction_model`.
    ///
    /// Along the way it may capture an edit-prediction example, which is sent as a
    /// telemetry event. This only happens when the active file, every event and every
    /// related file are collectable, and the sampling check passes.
    ///
    /// If the model returns no prediction, `allow_jump` is set and there are recorded
    /// edit events, it asks `next_diagnostic_location` for a diagnostic outside the
    /// searched rows. If one is found, it re-requests there with `allow_jump = false`,
    /// so at most one jump happens per call.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Rows above and below the cursor covered by the diagnostic search range.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        // Make sure project state exists; the lookup below relies on it (hence `unwrap`).
        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let stored_events = project_state.events(cx);
        let has_events = !stored_events.is_empty();
        let events: Vec<Arc<zeta_prompt::Event>> =
            stored_events.into_iter().map(|e| e.event).collect();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let current_offset = position.to_offset(&snapshot);

        // Work on a local copy of the recorded user actions; the stored history is
        // not modified here.
        let mut user_actions: Vec<UserActionRecord> =
            project_state.user_actions.iter().cloned().collect();

        // If the cursor moved within the same buffer since the last recorded action,
        // append a synthetic CursorMovement so the model sees the current position.
        if let Some(last_action) = user_actions.last() {
            if last_action.buffer_id == active_buffer.entity_id()
                && current_offset != last_action.offset
            {
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                user_actions.push(UserActionRecord {
                    action_type: UserActionType::CursorMovement,
                    buffer_id: active_buffer.entity_id(),
                    line_number: cursor_point.row,
                    offset: current_offset,
                    timestamp_epoch_ms,
                });
            }
        }
        // Rows within DIAGNOSTIC_LINES_RANGE of the cursor. This range goes to the model,
        // and diagnostics overlapping it are skipped when choosing a jump target.
        // NOTE(review): the end uses plain `+` while the start saturates.
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        // Related-file context is only gathered when context retrieval is enabled.
        let related_files = if self.use_context {
            self.context_for_project(&project, cx)
        } else {
            Vec::new()
        };

        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer.clone(),
            snapshot: snapshot.clone(),
            position,
            events,
            related_files,
            recent_paths: project_state.recent_paths.clone(),
            trigger,
            diagnostic_search_range: diagnostic_search_range.clone(),
            debug_tx,
            user_actions,
        };

        // Example capture requires that data collection is allowed for the active
        // file, for every event and for every related file.
        let can_collect_example = snapshot
            .file()
            .is_some_and(|file| self.can_collect_file(&project, file, cx))
            && self.can_collect_events(&inputs.events, cx)
            && self.can_collect_related_files(&project, cx);

        if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
            let events_for_capture =
                self.edit_history_for_project_with_pause_split_last_event(&project, cx);
            let related_files_for_capture = inputs.related_files.clone();
            if let Some(example_task) = capture_example::capture_example(
                project.clone(),
                active_buffer.clone(),
                position,
                events_for_capture,
                related_files_for_capture,
                false,
                cx,
            ) {
                // Capture runs in the background; failures are only logged.
                cx.spawn(async move |_this, _cx| {
                    let example = example_task.await?;
                    telemetry::event!("Edit Prediction Example Captured", example = example);
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx);
            }
        }
        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
            EditPredictionModel::Zeta2 { version } => {
                zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
            }
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            // Only fall back to a diagnostic jump when the model produced nothing.
            if prediction.is_none() && allow_jump {
                // Same value as the outer `cursor_point`, recomputed from the moved snapshot.
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    // Retry at the jump target with `allow_jump = false`, so this
                    // recursion happens at most once.
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1810
    /// Picks a location to retry a prediction at after the model returned nothing.
    ///
    /// Stage 1 runs in the active buffer. It takes the diagnostic groups whose primary
    /// entry does not overlap `active_buffer_diagnostic_search_range`, and picks the
    /// one whose start row is closest to `active_buffer_cursor_point`.
    ///
    /// Stage 2 runs only if stage 1 finds nothing. It looks at the project paths that
    /// have diagnostic summaries, excluding the active buffer's own path. It picks the
    /// path sharing the longest leading run of path components with the active buffer,
    /// opens it, and chooses the start of the entry with the smallest `severity`.
    /// NOTE(review): the smallest severity is presumably the most severe one, since LSP
    /// orders ERROR first.
    ///
    /// Returns `Ok(None)` when neither stage finds a location. Errors from opening the
    /// fallback buffer are propagated.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // `None` when the active buffer has no backing file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // `None` when no other path has diagnostic summaries.
            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        // (length of the common leading component prefix; worktree ids
                        // are not compared here)
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            });

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1889
1890    async fn send_raw_llm_request(
1891        request: RawCompletionRequest,
1892        client: Arc<Client>,
1893        custom_url: Option<Arc<Url>>,
1894        llm_token: LlmApiToken,
1895        app_version: Version,
1896    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1897        let url = if let Some(custom_url) = custom_url {
1898            custom_url.as_ref().clone()
1899        } else {
1900            client
1901                .http_client()
1902                .build_zed_llm_url("/predict_edits/raw", &[])?
1903        };
1904
1905        Self::send_api_request(
1906            |builder| {
1907                let req = builder
1908                    .uri(url.as_ref())
1909                    .body(serde_json::to_string(&request)?.into());
1910                Ok(req?)
1911            },
1912            client,
1913            llm_token,
1914            app_version,
1915            true,
1916        )
1917        .await
1918    }
1919
1920    pub(crate) async fn send_v3_request(
1921        input: ZetaPromptInput,
1922        prompt_version: ZetaVersion,
1923        client: Arc<Client>,
1924        llm_token: LlmApiToken,
1925        app_version: Version,
1926        trigger: PredictEditsRequestTrigger,
1927    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
1928        let url = client
1929            .http_client()
1930            .build_zed_llm_url("/predict_edits/v3", &[])?;
1931
1932        let request = PredictEditsV3Request {
1933            input,
1934            model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1935            prompt_version,
1936            trigger,
1937        };
1938
1939        Self::send_api_request(
1940            |builder| {
1941                let req = builder
1942                    .uri(url.as_ref())
1943                    .body(serde_json::to_string(&request)?.into());
1944                Ok(req?)
1945            },
1946            client,
1947            llm_token,
1948            app_version,
1949            true,
1950        )
1951        .await
1952    }
1953
1954    fn handle_api_response<T>(
1955        this: &WeakEntity<Self>,
1956        response: Result<(T, Option<EditPredictionUsage>)>,
1957        cx: &mut gpui::AsyncApp,
1958    ) -> Result<T> {
1959        match response {
1960            Ok((data, usage)) => {
1961                if let Some(usage) = usage {
1962                    this.update(cx, |this, cx| {
1963                        this.user_store.update(cx, |user_store, cx| {
1964                            user_store.update_edit_prediction_usage(usage, cx);
1965                        });
1966                    })
1967                    .ok();
1968                }
1969                Ok(data)
1970            }
1971            Err(err) => {
1972                if err.is::<ZedUpdateRequiredError>() {
1973                    cx.update(|cx| {
1974                        this.update(cx, |this, _cx| {
1975                            this.update_required = true;
1976                        })
1977                        .ok();
1978
1979                        let error_message: SharedString = err.to_string().into();
1980                        show_app_notification(
1981                            NotificationId::unique::<ZedUpdateRequiredError>(),
1982                            cx,
1983                            move |cx| {
1984                                cx.new(|cx| {
1985                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1986                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1987                                })
1988                            },
1989                        );
1990                    });
1991                }
1992                Err(err)
1993            }
1994        }
1995    }
1996
1997    async fn send_api_request<Res>(
1998        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1999        client: Arc<Client>,
2000        llm_token: LlmApiToken,
2001        app_version: Version,
2002        require_auth: bool,
2003    ) -> Result<(Res, Option<EditPredictionUsage>)>
2004    where
2005        Res: DeserializeOwned,
2006    {
2007        let http_client = client.http_client();
2008
2009        let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
2010            Some(custom_token)
2011        } else if require_auth {
2012            Some(llm_token.acquire(&client).await?)
2013        } else {
2014            llm_token.acquire(&client).await.ok()
2015        };
2016        let mut did_retry = false;
2017
2018        loop {
2019            let request_builder = http_client::Request::builder().method(Method::POST);
2020
2021            let mut request_builder = request_builder
2022                .header("Content-Type", "application/json")
2023                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2024
2025            // Only add Authorization header if we have a token
2026            if let Some(ref token_value) = token {
2027                request_builder =
2028                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2029            }
2030
2031            let request = build(request_builder)?;
2032
2033            let mut response = http_client.send(request).await?;
2034
2035            if let Some(minimum_required_version) = response
2036                .headers()
2037                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2038                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2039            {
2040                anyhow::ensure!(
2041                    app_version >= minimum_required_version,
2042                    ZedUpdateRequiredError {
2043                        minimum_version: minimum_required_version
2044                    }
2045                );
2046            }
2047
2048            if response.status().is_success() {
2049                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2050
2051                let mut body = Vec::new();
2052                response.body_mut().read_to_end(&mut body).await?;
2053                return Ok((serde_json::from_slice(&body)?, usage));
2054            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2055                did_retry = true;
2056                token = Some(llm_token.refresh(&client).await?);
2057            } else {
2058                let mut body = String::new();
2059                response.body_mut().read_to_string(&mut body).await?;
2060                anyhow::bail!(
2061                    "Request failed with status: {:?}\nBody: {}",
2062                    response.status(),
2063                    body
2064                );
2065            }
2066        }
2067    }
2068
2069    pub fn refresh_context(
2070        &mut self,
2071        project: &Entity<Project>,
2072        buffer: &Entity<language::Buffer>,
2073        cursor_position: language::Anchor,
2074        cx: &mut Context<Self>,
2075    ) {
2076        if self.use_context {
2077            self.get_or_init_project(project, cx)
2078                .context
2079                .update(cx, |store, cx| {
2080                    store.refresh(buffer.clone(), cursor_position, cx);
2081                });
2082        }
2083    }
2084
2085    #[cfg(feature = "cli-support")]
2086    pub fn set_context_for_buffer(
2087        &mut self,
2088        project: &Entity<Project>,
2089        related_files: Vec<RelatedFile>,
2090        cx: &mut Context<Self>,
2091    ) {
2092        self.get_or_init_project(project, cx)
2093            .context
2094            .update(cx, |store, cx| {
2095                store.set_related_files(related_files, cx);
2096            });
2097    }
2098
2099    #[cfg(feature = "cli-support")]
2100    pub fn set_recent_paths_for_project(
2101        &mut self,
2102        project: &Entity<Project>,
2103        paths: impl IntoIterator<Item = project::ProjectPath>,
2104        cx: &mut Context<Self>,
2105    ) {
2106        let project_state = self.get_or_init_project(project, cx);
2107        project_state.recent_paths = paths.into_iter().collect();
2108    }
2109
2110    fn is_file_open_source(
2111        &self,
2112        project: &Entity<Project>,
2113        file: &Arc<dyn File>,
2114        cx: &App,
2115    ) -> bool {
2116        if !file.is_local() || file.is_private() {
2117            return false;
2118        }
2119        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2120            return false;
2121        };
2122        project_state
2123            .license_detection_watchers
2124            .get(&file.worktree_id(cx))
2125            .as_ref()
2126            .is_some_and(|watcher| watcher.is_project_open_source())
2127    }
2128
2129    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2130        self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2131    }
2132
2133    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2134        if !self.data_collection_choice.is_enabled(cx) {
2135            return false;
2136        }
2137        events.iter().all(|event| {
2138            matches!(
2139                event.as_ref(),
2140                zeta_prompt::Event::BufferChange {
2141                    in_open_source_repo: true,
2142                    ..
2143                }
2144            )
2145        })
2146    }
2147
2148    fn can_collect_related_files(&self, project: &Entity<Project>, cx: &mut App) -> bool {
2149        if !self.data_collection_choice.is_enabled(cx) {
2150            return false;
2151        }
2152
2153        let related_with_buffers = self.context_for_project_with_buffers(project, cx);
2154
2155        related_with_buffers.iter().all(|(_, buffer)| {
2156            buffer
2157                .read(cx)
2158                .file()
2159                .is_some_and(|file| self.is_file_open_source(project, &file, cx))
2160        })
2161    }
2162
2163    fn load_data_collection_choice() -> DataCollectionChoice {
2164        let choice = KEY_VALUE_STORE
2165            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2166            .log_err()
2167            .flatten();
2168
2169        match choice.as_deref() {
2170            Some("true") => DataCollectionChoice::Enabled,
2171            Some("false") => DataCollectionChoice::Disabled,
2172            Some(_) => {
2173                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2174                DataCollectionChoice::NotAnswered
2175            }
2176            None => DataCollectionChoice::NotAnswered,
2177        }
2178    }
2179
2180    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2181        self.data_collection_choice = self.data_collection_choice.toggle();
2182        let new_choice = self.data_collection_choice;
2183        let is_enabled = new_choice.is_enabled(cx);
2184        db::write_and_log(cx, move || {
2185            KEY_VALUE_STORE.write_kvp(
2186                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2187                is_enabled.to_string(),
2188            )
2189        });
2190    }
2191
    /// Iterates (double-ended) over the predictions recorded in `shown_predictions`.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2195
    /// Number of entries in `shown_predictions`, i.e. the length of what
    /// `shown_predictions()` yields.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2199
    /// Returns whether `id` is in the rated-prediction set, which `rate_prediction`
    /// adds to.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2203
2204    pub fn rate_prediction(
2205        &mut self,
2206        prediction: &EditPrediction,
2207        rating: EditPredictionRating,
2208        feedback: String,
2209        cx: &mut Context<Self>,
2210    ) {
2211        self.rated_predictions.insert(prediction.id.clone());
2212        telemetry::event!(
2213            "Edit Prediction Rated",
2214            rating,
2215            inputs = prediction.inputs,
2216            output = prediction
2217                .edit_preview
2218                .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2219            feedback
2220        );
2221        self.client.telemetry().flush_events().detach();
2222        cx.notify();
2223    }
2224
2225    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
2226        if cfg!(feature = "cli-support") {
2227            return;
2228        }
2229        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
2230            && all_language_settings(None, cx).edit_predictions.use_context;
2231    }
2232}
2233
2234pub(crate) fn filter_redundant_excerpts(
2235    mut related_files: Vec<RelatedFile>,
2236    cursor_path: &Path,
2237    cursor_row_range: Range<u32>,
2238) -> Vec<RelatedFile> {
2239    for file in &mut related_files {
2240        if file.path.as_ref() == cursor_path {
2241            file.excerpts.retain(|excerpt| {
2242                excerpt.row_range.start < cursor_row_range.start
2243                    || excerpt.row_range.end > cursor_row_range.end
2244            });
2245        }
2246    }
2247    related_files.retain(|file| !file.excerpts.is_empty());
2248    related_files
2249}
2250
/// Error raised by `send_api_request` when the server's minimum-required-version
/// header names a version newer than the running app.
///
/// `handle_api_response` reacts to it by flagging the store as needing an update
/// and showing an "Update Zed" notification.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum app version required by the server.
    minimum_version: Version,
}
2258
/// The user's answer to the edit-prediction data-collection prompt, stored under
/// `ZED_PREDICT_DATA_COLLECTION_CHOICE` in the key-value store.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No stored value, or a stored value that was not recognized.
    NotAnswered,
    /// Loaded from a stored `"true"`.
    Enabled,
    /// Loaded from a stored `"false"`.
    Disabled,
}
2265
2266impl DataCollectionChoice {
2267    pub fn is_enabled(self, cx: &App) -> bool {
2268        if cx.is_staff() {
2269            return true;
2270        }
2271        match self {
2272            Self::Enabled => true,
2273            Self::NotAnswered | Self::Disabled => false,
2274        }
2275    }
2276
2277    #[must_use]
2278    pub fn toggle(&self) -> DataCollectionChoice {
2279        match self {
2280            Self::Enabled => Self::Disabled,
2281            Self::Disabled => Self::Enabled,
2282            Self::NotAnswered => Self::Enabled,
2283        }
2284    }
2285}
2286
2287impl From<bool> for DataCollectionChoice {
2288    fn from(value: bool) -> Self {
2289        match value {
2290            true => DataCollectionChoice::Enabled,
2291            false => DataCollectionChoice::Disabled,
2292        }
2293    }
2294}
2295
/// Marker type for the dismissible Zed Predict upsell; see `should_show_upsell_modal`.
struct ZedPredictUpsell;
2297
2298impl Dismissable for ZedPredictUpsell {
2299    const KEY: &'static str = "dismissed-edit-predict-upsell";
2300
2301    fn dismissed() -> bool {
2302        // To make this backwards compatible with older versions of Zed, we
2303        // check if the user has seen the previous Edit Prediction Onboarding
2304        // before, by checking the data collection choice which was written to
2305        // the database once the user clicked on "Accept and Enable"
2306        if KEY_VALUE_STORE
2307            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2308            .log_err()
2309            .is_some_and(|s| s.is_some())
2310        {
2311            return true;
2312        }
2313
2314        KEY_VALUE_STORE
2315            .read_kvp(Self::KEY)
2316            .log_err()
2317            .is_some_and(|s| s.is_some())
2318    }
2319}
2320
/// Returns `true` while the Zed Predict upsell has not been dismissed.
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2324
/// Registers edit-prediction workspace actions on every newly created `Workspace`.
///
/// The actions are:
/// - opening the Zed Predict onboarding modal,
/// - resetting onboarding,
/// - the Copilot sign-in, reinstall and sign-out actions.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding writes `edit_prediction_provider = None` to the settings file.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
        // Resolves the Copilot instance for `project` through the global
        // `EditPredictionStore`. Returns `None` when no global store exists or when
        // `start_copilot_for_project` yields none.
        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
            EditPredictionStore::try_global(cx).and_then(|store| {
                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
            })
        }

        // The Copilot actions below do nothing when no Copilot instance is available.
        workspace.register_action(|workspace, _: &SignIn, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &SignOut, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_out(copilot, window, cx);
            }
        });
    })
    .detach();
}