edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{
   5    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   6};
   7use cloud_llm_client::{
   8    EditPredictionRejectReason, EditPredictionRejection,
   9    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  10    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  11};
  12use collections::{HashMap, HashSet};
  13use copilot::{Copilot, Reinstall, SignIn, SignOut};
  14use db::kvp::{Dismissable, KEY_VALUE_STORE};
  15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::{
  18    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  19    channel::mpsc::{self, UnboundedReceiver},
  20    select_biased,
  21};
  22use gpui::BackgroundExecutor;
  23use gpui::http_client::Url;
  24use gpui::{
  25    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  26    http_client::{self, AsyncBody, Method},
  27    prelude::*,
  28};
  29use language::language_settings::all_language_settings;
  30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  31use language::{BufferSnapshot, OffsetRangeExt};
  32use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  33use project::{Project, ProjectPath, WorktreeId};
  34use release_channel::AppVersion;
  35use semver::Version;
  36use serde::de::DeserializeOwned;
  37use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
  38use std::collections::{VecDeque, hash_map};
  39use text::Edit;
  40use workspace::Workspace;
  41use zeta_prompt::ZetaPromptInput;
  42use zeta_prompt::ZetaVersion;
  43
  44use std::ops::Range;
  45use std::path::Path;
  46use std::rc::Rc;
  47use std::str::FromStr as _;
  48use std::sync::{Arc, LazyLock};
  49use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  50use std::{env, mem};
  51use thiserror::Error;
  52use util::{RangeExt as _, ResultExt as _};
  53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  54
  55pub mod cursor_excerpt;
  56pub mod example_spec;
  57mod license_detection;
  58pub mod mercury;
  59mod onboarding_modal;
  60pub mod open_ai_response;
  61mod prediction;
  62pub mod sweep_ai;
  63
  64pub mod udiff;
  65
  66mod capture_example;
  67mod zed_edit_prediction_delegate;
  68pub mod zeta1;
  69pub mod zeta2;
  70
  71#[cfg(test)]
  72mod edit_prediction_tests;
  73
  74use crate::capture_example::{
  75    should_sample_edit_prediction_example_capture, should_send_testing_zeta2_request,
  76};
  77use crate::license_detection::LicenseDetectionWatcher;
  78use crate::mercury::Mercury;
  79use crate::onboarding_modal::ZedPredictModal;
  80pub use crate::prediction::EditPrediction;
  81pub use crate::prediction::EditPredictionId;
  82use crate::prediction::EditPredictionResult;
  83pub use crate::sweep_ai::SweepAi;
  84pub use capture_example::capture_example;
  85pub use language_model::ApiKeyState;
  86pub use telemetry_events::EditPredictionRating;
  87pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  88
  89actions!(
  90    edit_prediction,
  91    [
  92        /// Resets the edit prediction onboarding state.
  93        ResetOnboarding,
  94        /// Clears the edit prediction history.
  95        ClearHistory,
  96    ]
  97);
  98
  99/// Maximum number of events to track.
 100const EVENT_COUNT_MAX: usize = 6;
 101const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 102const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
 103const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 104const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 105
 106pub struct SweepFeatureFlag;
 107
 108impl FeatureFlag for SweepFeatureFlag {
 109    const NAME: &str = "sweep-ai";
 110}
 111
 112pub struct MercuryFeatureFlag;
 113
 114impl FeatureFlag for MercuryFeatureFlag {
 115    const NAME: &str = "mercury";
 116}
 117
 118static EDIT_PREDICTIONS_MODEL_ID: LazyLock<Option<String>> =
 119    LazyLock::new(|| env::var("ZED_ZETA_MODEL").ok());
 120
 121pub struct Zeta2FeatureFlag;
 122
 123impl FeatureFlag for Zeta2FeatureFlag {
 124    const NAME: &'static str = "zeta2";
 125
 126    fn enabled_for_staff() -> bool {
 127        true
 128    }
 129}
 130
 131pub struct EditPredictionExampleCaptureFeatureFlag;
 132
 133impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
 134    const NAME: &'static str = "edit-prediction-example-capture";
 135
 136    fn enabled_for_staff() -> bool {
 137        true
 138    }
 139}
 140
 141#[derive(Clone)]
 142struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
 143
 144impl Global for EditPredictionStoreGlobal {}
 145
 146pub struct EditPredictionStore {
 147    client: Arc<Client>,
 148    user_store: Entity<UserStore>,
 149    llm_token: LlmApiToken,
 150    _llm_token_subscription: Subscription,
 151    projects: HashMap<EntityId, ProjectState>,
 152    use_context: bool,
 153    update_required: bool,
 154    edit_prediction_model: EditPredictionModel,
 155    pub sweep_ai: SweepAi,
 156    pub mercury: Mercury,
 157    data_collection_choice: DataCollectionChoice,
 158    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
 159    shown_predictions: VecDeque<EditPrediction>,
 160    rated_predictions: HashSet<EditPredictionId>,
 161    custom_predict_edits_url: Option<Arc<Url>>,
 162}
 163
 164#[derive(Copy, Clone, Default, PartialEq, Eq)]
 165pub enum EditPredictionModel {
 166    #[default]
 167    Zeta1,
 168    Zeta2 {
 169        version: ZetaVersion,
 170    },
 171    Sweep,
 172    Mercury,
 173}
 174
 175#[derive(Clone)]
 176pub struct EditPredictionModelInput {
 177    project: Entity<Project>,
 178    buffer: Entity<Buffer>,
 179    snapshot: BufferSnapshot,
 180    position: Anchor,
 181    events: Vec<Arc<zeta_prompt::Event>>,
 182    related_files: Vec<RelatedFile>,
 183    recent_paths: VecDeque<ProjectPath>,
 184    trigger: PredictEditsRequestTrigger,
 185    diagnostic_search_range: Range<Point>,
 186    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 187    pub user_actions: Vec<UserActionRecord>,
 188}
 189
 190#[derive(Debug)]
 191pub enum DebugEvent {
 192    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
 193    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
 194    EditPredictionStarted(EditPredictionStartedDebugEvent),
 195    EditPredictionFinished(EditPredictionFinishedDebugEvent),
 196}
 197
 198#[derive(Debug)]
 199pub struct ContextRetrievalStartedDebugEvent {
 200    pub project_entity_id: EntityId,
 201    pub timestamp: Instant,
 202    pub search_prompt: String,
 203}
 204
 205#[derive(Debug)]
 206pub struct ContextRetrievalFinishedDebugEvent {
 207    pub project_entity_id: EntityId,
 208    pub timestamp: Instant,
 209    pub metadata: Vec<(&'static str, SharedString)>,
 210}
 211
 212#[derive(Debug)]
 213pub struct EditPredictionStartedDebugEvent {
 214    pub buffer: WeakEntity<Buffer>,
 215    pub position: Anchor,
 216    pub prompt: Option<String>,
 217}
 218
 219#[derive(Debug)]
 220pub struct EditPredictionFinishedDebugEvent {
 221    pub buffer: WeakEntity<Buffer>,
 222    pub position: Anchor,
 223    pub model_output: Option<String>,
 224}
 225
 226const USER_ACTION_HISTORY_SIZE: usize = 16;
 227
 228#[derive(Clone, Debug)]
 229pub struct UserActionRecord {
 230    pub action_type: UserActionType,
 231    pub buffer_id: EntityId,
 232    pub line_number: u32,
 233    pub offset: usize,
 234    pub timestamp_epoch_ms: u64,
 235}
 236
 237#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 238pub enum UserActionType {
 239    InsertChar,
 240    InsertSelection,
 241    DeleteChar,
 242    DeleteSelection,
 243    CursorMovement,
 244}
 245
 246/// An event with associated metadata for reconstructing buffer state.
 247#[derive(Clone)]
 248pub struct StoredEvent {
 249    pub event: Arc<zeta_prompt::Event>,
 250    pub old_snapshot: TextBufferSnapshot,
 251}
 252
 253struct ProjectState {
 254    events: VecDeque<StoredEvent>,
 255    last_event: Option<LastEvent>,
 256    recent_paths: VecDeque<ProjectPath>,
 257    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 258    current_prediction: Option<CurrentEditPrediction>,
 259    next_pending_prediction_id: usize,
 260    pending_predictions: ArrayVec<PendingPrediction, 2>,
 261    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 262    last_prediction_refresh: Option<(EntityId, Instant)>,
 263    cancelled_predictions: HashSet<usize>,
 264    context: Entity<RelatedExcerptStore>,
 265    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 266    user_actions: VecDeque<UserActionRecord>,
 267    _subscription: gpui::Subscription,
 268    copilot: Option<Entity<Copilot>>,
 269}
 270
 271impl ProjectState {
 272    fn record_user_action(&mut self, action: UserActionRecord) {
 273        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 274            self.user_actions.pop_front();
 275        }
 276        self.user_actions.push_back(action);
 277    }
 278
 279    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 280        self.events
 281            .iter()
 282            .cloned()
 283            .chain(
 284                self.last_event
 285                    .as_ref()
 286                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 287            )
 288            .collect()
 289    }
 290
 291    pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
 292        self.events
 293            .iter()
 294            .cloned()
 295            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 296                let (one, two) = event.split_by_pause();
 297                let one = one.finalize(&self.license_detection_watchers, cx);
 298                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 299                one.into_iter().chain(two)
 300            }))
 301            .collect()
 302    }
 303
 304    fn cancel_pending_prediction(
 305        &mut self,
 306        pending_prediction: PendingPrediction,
 307        cx: &mut Context<EditPredictionStore>,
 308    ) {
 309        self.cancelled_predictions.insert(pending_prediction.id);
 310
 311        cx.spawn(async move |this, cx| {
 312            let Some(prediction_id) = pending_prediction.task.await else {
 313                return;
 314            };
 315
 316            this.update(cx, |this, _cx| {
 317                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 318            })
 319            .ok();
 320        })
 321        .detach()
 322    }
 323
 324    fn active_buffer(
 325        &self,
 326        project: &Entity<Project>,
 327        cx: &App,
 328    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 329        let project = project.read(cx);
 330        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 331        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 332        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 333        Some((active_buffer, registered_buffer.last_position))
 334    }
 335}
 336
 337#[derive(Debug, Clone)]
 338struct CurrentEditPrediction {
 339    pub requested_by: PredictionRequestedBy,
 340    pub prediction: EditPrediction,
 341    pub was_shown: bool,
 342    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
 343}
 344
 345impl CurrentEditPrediction {
 346    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 347        let Some(new_edits) = self
 348            .prediction
 349            .interpolate(&self.prediction.buffer.read(cx))
 350        else {
 351            return false;
 352        };
 353
 354        if self.prediction.buffer != old_prediction.prediction.buffer {
 355            return true;
 356        }
 357
 358        let Some(old_edits) = old_prediction
 359            .prediction
 360            .interpolate(&old_prediction.prediction.buffer.read(cx))
 361        else {
 362            return true;
 363        };
 364
 365        let requested_by_buffer_id = self.requested_by.buffer_id();
 366
 367        // This reduces the occurrence of UI thrash from replacing edits
 368        //
 369        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 370        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 371            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 372            && old_edits.len() == 1
 373            && new_edits.len() == 1
 374        {
 375            let (old_range, old_text) = &old_edits[0];
 376            let (new_range, new_text) = &new_edits[0];
 377            new_range == old_range && new_text.starts_with(old_text.as_ref())
 378        } else {
 379            true
 380        }
 381    }
 382}
 383
 384#[derive(Debug, Clone)]
 385enum PredictionRequestedBy {
 386    DiagnosticsUpdate,
 387    Buffer(EntityId),
 388}
 389
 390impl PredictionRequestedBy {
 391    pub fn buffer_id(&self) -> Option<EntityId> {
 392        match self {
 393            PredictionRequestedBy::DiagnosticsUpdate => None,
 394            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 395        }
 396    }
 397}
 398
 399#[derive(Debug)]
 400struct PendingPrediction {
 401    id: usize,
 402    task: Task<Option<EditPredictionId>>,
 403}
 404
 405/// A prediction from the perspective of a buffer.
 406#[derive(Debug)]
 407enum BufferEditPrediction<'a> {
 408    Local { prediction: &'a EditPrediction },
 409    Jump { prediction: &'a EditPrediction },
 410}
 411
 412#[cfg(test)]
 413impl std::ops::Deref for BufferEditPrediction<'_> {
 414    type Target = EditPrediction;
 415
 416    fn deref(&self) -> &Self::Target {
 417        match self {
 418            BufferEditPrediction::Local { prediction } => prediction,
 419            BufferEditPrediction::Jump { prediction } => prediction,
 420        }
 421    }
 422}
 423
 424struct RegisteredBuffer {
 425    file: Option<Arc<dyn File>>,
 426    snapshot: TextBufferSnapshot,
 427    last_position: Option<Anchor>,
 428    _subscriptions: [gpui::Subscription; 2],
 429}
 430
 431#[derive(Clone)]
 432struct LastEvent {
 433    old_snapshot: TextBufferSnapshot,
 434    new_snapshot: TextBufferSnapshot,
 435    old_file: Option<Arc<dyn File>>,
 436    new_file: Option<Arc<dyn File>>,
 437    edit_range: Option<Range<Anchor>>,
 438    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
 439    last_edit_time: Option<Instant>,
 440}
 441
 442impl LastEvent {
 443    pub fn finalize(
 444        &self,
 445        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 446        cx: &App,
 447    ) -> Option<StoredEvent> {
 448        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 449        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 450
 451        let in_open_source_repo =
 452            [self.new_file.as_ref(), self.old_file.as_ref()]
 453                .iter()
 454                .all(|file| {
 455                    file.is_some_and(|file| {
 456                        license_detection_watchers
 457                            .get(&file.worktree_id(cx))
 458                            .is_some_and(|watcher| watcher.is_project_open_source())
 459                    })
 460                });
 461
 462        let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 463
 464        if path == old_path && diff.is_empty() {
 465            None
 466        } else {
 467            Some(StoredEvent {
 468                event: Arc::new(zeta_prompt::Event::BufferChange {
 469                    old_path,
 470                    path,
 471                    diff,
 472                    in_open_source_repo,
 473                    // TODO: Actually detect if this edit was predicted or not
 474                    predicted: false,
 475                }),
 476                old_snapshot: self.old_snapshot.clone(),
 477            })
 478        }
 479    }
 480
 481    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 482        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 483            return (self.clone(), None);
 484        };
 485
 486        let before = LastEvent {
 487            old_snapshot: self.old_snapshot.clone(),
 488            new_snapshot: boundary_snapshot.clone(),
 489            old_file: self.old_file.clone(),
 490            new_file: self.new_file.clone(),
 491            edit_range: None,
 492            snapshot_after_last_editing_pause: None,
 493            last_edit_time: self.last_edit_time,
 494        };
 495
 496        let after = LastEvent {
 497            old_snapshot: boundary_snapshot.clone(),
 498            new_snapshot: self.new_snapshot.clone(),
 499            old_file: self.old_file.clone(),
 500            new_file: self.new_file.clone(),
 501            edit_range: None,
 502            snapshot_after_last_editing_pause: None,
 503            last_edit_time: self.last_edit_time,
 504        };
 505
 506        (before, Some(after))
 507    }
 508}
 509
 510pub(crate) fn compute_diff_between_snapshots(
 511    old_snapshot: &TextBufferSnapshot,
 512    new_snapshot: &TextBufferSnapshot,
 513) -> Option<String> {
 514    let edits: Vec<Edit<usize>> = new_snapshot
 515        .edits_since::<usize>(&old_snapshot.version)
 516        .collect();
 517
 518    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 519
 520    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 521    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 522    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 523    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 524
 525    const CONTEXT_LINES: u32 = 3;
 526
 527    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 528    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 529    let old_context_end_row =
 530        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 531    let new_context_end_row =
 532        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 533
 534    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 535    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 536    let old_end_line_offset = old_snapshot
 537        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 538    let new_end_line_offset = new_snapshot
 539        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 540    let old_edit_range = old_start_line_offset..old_end_line_offset;
 541    let new_edit_range = new_start_line_offset..new_end_line_offset;
 542
 543    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 544    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 545
 546    let diff = language::unified_diff_with_offsets(
 547        &old_region_text,
 548        &new_region_text,
 549        old_context_start_row,
 550        new_context_start_row,
 551    );
 552
 553    Some(diff)
 554}
 555
 556fn buffer_path_with_id_fallback(
 557    file: Option<&Arc<dyn File>>,
 558    snapshot: &TextBufferSnapshot,
 559    cx: &App,
 560) -> Arc<Path> {
 561    if let Some(file) = file {
 562        file.full_path(cx).into()
 563    } else {
 564        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 565    }
 566}
 567
 568impl EditPredictionStore {
 569    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 570        cx.try_global::<EditPredictionStoreGlobal>()
 571            .map(|global| global.0.clone())
 572    }
 573
 574    pub fn global(
 575        client: &Arc<Client>,
 576        user_store: &Entity<UserStore>,
 577        cx: &mut App,
 578    ) -> Entity<Self> {
 579        cx.try_global::<EditPredictionStoreGlobal>()
 580            .map(|global| global.0.clone())
 581            .unwrap_or_else(|| {
 582                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 583                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 584                ep_store
 585            })
 586    }
 587
 588    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 589        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 590        let data_collection_choice = Self::load_data_collection_choice();
 591
 592        let llm_token = LlmApiToken::default();
 593
 594        let (reject_tx, reject_rx) = mpsc::unbounded();
 595        cx.background_spawn({
 596            let client = client.clone();
 597            let llm_token = llm_token.clone();
 598            let app_version = AppVersion::global(cx);
 599            let background_executor = cx.background_executor().clone();
 600            async move {
 601                Self::handle_rejected_predictions(
 602                    reject_rx,
 603                    client,
 604                    llm_token,
 605                    app_version,
 606                    background_executor,
 607                )
 608                .await
 609            }
 610        })
 611        .detach();
 612
 613        let mut this = Self {
 614            projects: HashMap::default(),
 615            client,
 616            user_store,
 617            use_context: false,
 618            llm_token,
 619            _llm_token_subscription: cx.subscribe(
 620                &refresh_llm_token_listener,
 621                |this, _listener, _event, cx| {
 622                    let client = this.client.clone();
 623                    let llm_token = this.llm_token.clone();
 624                    cx.spawn(async move |_this, _cx| {
 625                        llm_token.refresh(&client).await?;
 626                        anyhow::Ok(())
 627                    })
 628                    .detach_and_log_err(cx);
 629                },
 630            ),
 631            update_required: false,
 632            edit_prediction_model: EditPredictionModel::Zeta2 {
 633                version: Default::default(),
 634            },
 635            sweep_ai: SweepAi::new(cx),
 636            mercury: Mercury::new(cx),
 637
 638            data_collection_choice,
 639            reject_predictions_tx: reject_tx,
 640            rated_predictions: Default::default(),
 641            shown_predictions: Default::default(),
 642            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
 643                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
 644                Err(_) => None,
 645            },
 646        };
 647
 648        this.configure_context_retrieval(cx);
 649        let weak_this = cx.weak_entity();
 650        cx.on_flags_ready(move |_, cx| {
 651            weak_this
 652                .update(cx, |this, cx| this.configure_context_retrieval(cx))
 653                .ok();
 654        })
 655        .detach();
 656        cx.observe_global::<SettingsStore>(|this, cx| {
 657            this.configure_context_retrieval(cx);
 658        })
 659        .detach();
 660
 661        this
 662    }
 663
 664    #[cfg(test)]
 665    pub fn set_custom_predict_edits_url(&mut self, url: Url) {
 666        self.custom_predict_edits_url = Some(url.into());
 667    }
 668
 669    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 670        self.edit_prediction_model = model;
 671    }
 672
 673    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 674        self.sweep_ai.api_token.read(cx).has_key()
 675    }
 676
 677    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 678        self.mercury.api_token.read(cx).has_key()
 679    }
 680
 681    pub fn set_use_context(&mut self, use_context: bool) {
 682        self.use_context = use_context;
 683    }
 684
 685    pub fn clear_history(&mut self) {
 686        for project_state in self.projects.values_mut() {
 687            project_state.events.clear();
 688            project_state.last_event.take();
 689        }
 690    }
 691
 692    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 693        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 694            project_state.events.clear();
 695            project_state.last_event.take();
 696        }
 697    }
 698
 699    pub fn edit_history_for_project(
 700        &self,
 701        project: &Entity<Project>,
 702        cx: &App,
 703    ) -> Vec<StoredEvent> {
 704        self.projects
 705            .get(&project.entity_id())
 706            .map(|project_state| project_state.events(cx))
 707            .unwrap_or_default()
 708    }
 709
 710    pub fn edit_history_for_project_with_pause_split_last_event(
 711        &self,
 712        project: &Entity<Project>,
 713        cx: &App,
 714    ) -> Vec<StoredEvent> {
 715        self.projects
 716            .get(&project.entity_id())
 717            .map(|project_state| project_state.events_split_by_pause(cx))
 718            .unwrap_or_default()
 719    }
 720
 721    pub fn context_for_project<'a>(
 722        &'a self,
 723        project: &Entity<Project>,
 724        cx: &'a mut App,
 725    ) -> Vec<RelatedFile> {
 726        self.projects
 727            .get(&project.entity_id())
 728            .map(|project| {
 729                project
 730                    .context
 731                    .update(cx, |context, cx| context.related_files(cx))
 732            })
 733            .unwrap_or_default()
 734    }
 735
 736    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 737        self.projects
 738            .get(&project.entity_id())
 739            .and_then(|project| project.copilot.clone())
 740    }
 741
 742    pub fn start_copilot_for_project(
 743        &mut self,
 744        project: &Entity<Project>,
 745        cx: &mut Context<Self>,
 746    ) -> Option<Entity<Copilot>> {
 747        let state = self.get_or_init_project(project, cx);
 748
 749        if state.copilot.is_some() {
 750            return state.copilot.clone();
 751        }
 752        let _project = project.clone();
 753        let project = project.read(cx);
 754
 755        let node = project.node_runtime().cloned();
 756        if let Some(node) = node {
 757            let next_id = project.languages().next_language_server_id();
 758            let fs = project.fs().clone();
 759
 760            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 761            state.copilot = Some(copilot.clone());
 762            Some(copilot)
 763        } else {
 764            None
 765        }
 766    }
 767
 768    pub fn context_for_project_with_buffers<'a>(
 769        &'a self,
 770        project: &Entity<Project>,
 771        cx: &'a mut App,
 772    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 773        self.projects
 774            .get(&project.entity_id())
 775            .map(|project| {
 776                project
 777                    .context
 778                    .update(cx, |context, cx| context.related_files_with_buffers(cx))
 779            })
 780            .unwrap_or_default()
 781    }
 782
 783    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 784        if matches!(
 785            self.edit_prediction_model,
 786            EditPredictionModel::Zeta2 { .. }
 787        ) {
 788            self.user_store.read(cx).edit_prediction_usage()
 789        } else {
 790            None
 791        }
 792    }
 793
 794    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 795        self.get_or_init_project(project, cx);
 796    }
 797
 798    pub fn register_buffer(
 799        &mut self,
 800        buffer: &Entity<Buffer>,
 801        project: &Entity<Project>,
 802        cx: &mut Context<Self>,
 803    ) {
 804        let project_state = self.get_or_init_project(project, cx);
 805        Self::register_buffer_impl(project_state, buffer, project, cx);
 806    }
 807
 808    fn get_or_init_project(
 809        &mut self,
 810        project: &Entity<Project>,
 811        cx: &mut Context<Self>,
 812    ) -> &mut ProjectState {
 813        let entity_id = project.entity_id();
 814        self.projects
 815            .entry(entity_id)
 816            .or_insert_with(|| ProjectState {
 817                context: {
 818                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 819                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 820                        this.handle_excerpt_store_event(entity_id, event);
 821                    })
 822                    .detach();
 823                    related_excerpt_store
 824                },
 825                events: VecDeque::new(),
 826                last_event: None,
 827                recent_paths: VecDeque::new(),
 828                debug_tx: None,
 829                registered_buffers: HashMap::default(),
 830                current_prediction: None,
 831                cancelled_predictions: HashSet::default(),
 832                pending_predictions: ArrayVec::new(),
 833                next_pending_prediction_id: 0,
 834                last_prediction_refresh: None,
 835                license_detection_watchers: HashMap::default(),
 836                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
 837                _subscription: cx.subscribe(&project, Self::handle_project_event),
 838                copilot: None,
 839            })
 840    }
 841
 842    pub fn remove_project(&mut self, project: &Entity<Project>) {
 843        self.projects.remove(&project.entity_id());
 844    }
 845
 846    fn handle_excerpt_store_event(
 847        &mut self,
 848        project_entity_id: EntityId,
 849        event: &RelatedExcerptStoreEvent,
 850    ) {
 851        if let Some(project_state) = self.projects.get(&project_entity_id) {
 852            if let Some(debug_tx) = project_state.debug_tx.clone() {
 853                match event {
 854                    RelatedExcerptStoreEvent::StartedRefresh => {
 855                        debug_tx
 856                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 857                                ContextRetrievalStartedDebugEvent {
 858                                    project_entity_id: project_entity_id,
 859                                    timestamp: Instant::now(),
 860                                    search_prompt: String::new(),
 861                                },
 862                            ))
 863                            .ok();
 864                    }
 865                    RelatedExcerptStoreEvent::FinishedRefresh {
 866                        cache_hit_count,
 867                        cache_miss_count,
 868                        mean_definition_latency,
 869                        max_definition_latency,
 870                    } => {
 871                        debug_tx
 872                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 873                                ContextRetrievalFinishedDebugEvent {
 874                                    project_entity_id: project_entity_id,
 875                                    timestamp: Instant::now(),
 876                                    metadata: vec![
 877                                        (
 878                                            "Cache Hits",
 879                                            format!(
 880                                                "{}/{}",
 881                                                cache_hit_count,
 882                                                cache_hit_count + cache_miss_count
 883                                            )
 884                                            .into(),
 885                                        ),
 886                                        (
 887                                            "Max LSP Time",
 888                                            format!("{} ms", max_definition_latency.as_millis())
 889                                                .into(),
 890                                        ),
 891                                        (
 892                                            "Mean LSP Time",
 893                                            format!("{} ms", mean_definition_latency.as_millis())
 894                                                .into(),
 895                                        ),
 896                                    ],
 897                                },
 898                            ))
 899                            .ok();
 900                    }
 901                }
 902            }
 903        }
 904    }
 905
 906    pub fn debug_info(
 907        &mut self,
 908        project: &Entity<Project>,
 909        cx: &mut Context<Self>,
 910    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 911        let project_state = self.get_or_init_project(project, cx);
 912        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 913        project_state.debug_tx = Some(debug_watch_tx);
 914        debug_watch_rx
 915    }
 916
 917    fn handle_project_event(
 918        &mut self,
 919        project: Entity<Project>,
 920        event: &project::Event,
 921        cx: &mut Context<Self>,
 922    ) {
 923        // TODO [zeta2] init with recent paths
 924        match event {
 925            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 926                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 927                    return;
 928                };
 929                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 930                if let Some(path) = path {
 931                    if let Some(ix) = project_state
 932                        .recent_paths
 933                        .iter()
 934                        .position(|probe| probe == &path)
 935                    {
 936                        project_state.recent_paths.remove(ix);
 937                    }
 938                    project_state.recent_paths.push_front(path);
 939                }
 940            }
 941            project::Event::DiagnosticsUpdated { .. } => {
 942                if cx.has_flag::<Zeta2FeatureFlag>() {
 943                    self.refresh_prediction_from_diagnostics(project, cx);
 944                }
 945            }
 946            _ => (),
 947        }
 948    }
 949
 950    fn register_buffer_impl<'a>(
 951        project_state: &'a mut ProjectState,
 952        buffer: &Entity<Buffer>,
 953        project: &Entity<Project>,
 954        cx: &mut Context<Self>,
 955    ) -> &'a mut RegisteredBuffer {
 956        let buffer_id = buffer.entity_id();
 957
 958        if let Some(file) = buffer.read(cx).file() {
 959            let worktree_id = file.worktree_id(cx);
 960            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
 961                project_state
 962                    .license_detection_watchers
 963                    .entry(worktree_id)
 964                    .or_insert_with(|| {
 965                        let project_entity_id = project.entity_id();
 966                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
 967                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
 968                            else {
 969                                return;
 970                            };
 971                            project_state
 972                                .license_detection_watchers
 973                                .remove(&worktree_id);
 974                        })
 975                        .detach();
 976                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
 977                    });
 978            }
 979        }
 980
 981        match project_state.registered_buffers.entry(buffer_id) {
 982            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 983            hash_map::Entry::Vacant(entry) => {
 984                let buf = buffer.read(cx);
 985                let snapshot = buf.text_snapshot();
 986                let file = buf.file().cloned();
 987                let project_entity_id = project.entity_id();
 988                entry.insert(RegisteredBuffer {
 989                    snapshot,
 990                    file,
 991                    last_position: None,
 992                    _subscriptions: [
 993                        cx.subscribe(buffer, {
 994                            let project = project.downgrade();
 995                            move |this, buffer, event, cx| {
 996                                if let language::BufferEvent::Edited = event
 997                                    && let Some(project) = project.upgrade()
 998                                {
 999                                    this.report_changes_for_buffer(&buffer, &project, cx);
1000                                }
1001                            }
1002                        }),
1003                        cx.observe_release(buffer, move |this, _buffer, _cx| {
1004                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
1005                            else {
1006                                return;
1007                            };
1008                            project_state.registered_buffers.remove(&buffer_id);
1009                        }),
1010                    ],
1011                })
1012            }
1013        }
1014    }
1015
1016    fn report_changes_for_buffer(
1017        &mut self,
1018        buffer: &Entity<Buffer>,
1019        project: &Entity<Project>,
1020        cx: &mut Context<Self>,
1021    ) {
1022        let project_state = self.get_or_init_project(project, cx);
1023        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
1024
1025        let buf = buffer.read(cx);
1026        let new_file = buf.file().cloned();
1027        let new_snapshot = buf.text_snapshot();
1028        if new_snapshot.version == registered_buffer.snapshot.version {
1029            return;
1030        }
1031
1032        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
1033        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1034        let mut num_edits = 0usize;
1035        let mut total_deleted = 0usize;
1036        let mut total_inserted = 0usize;
1037        let mut edit_range: Option<Range<Anchor>> = None;
1038        let mut last_offset: Option<usize> = None;
1039
1040        for (edit, anchor_range) in
1041            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
1042        {
1043            num_edits += 1;
1044            total_deleted += edit.old.len();
1045            total_inserted += edit.new.len();
1046            edit_range = Some(match edit_range {
1047                None => anchor_range,
1048                Some(acc) => acc.start..anchor_range.end,
1049            });
1050            last_offset = Some(edit.new.end);
1051        }
1052
1053        if num_edits > 0 {
1054            let action_type = match (total_deleted, total_inserted, num_edits) {
1055                (0, ins, n) if ins == n => UserActionType::InsertChar,
1056                (0, _, _) => UserActionType::InsertSelection,
1057                (del, 0, n) if del == n => UserActionType::DeleteChar,
1058                (_, 0, _) => UserActionType::DeleteSelection,
1059                (_, ins, n) if ins == n => UserActionType::InsertChar,
1060                (_, _, _) => UserActionType::InsertSelection,
1061            };
1062
1063            if let Some(offset) = last_offset {
1064                let point = new_snapshot.offset_to_point(offset);
1065                let timestamp_epoch_ms = SystemTime::now()
1066                    .duration_since(UNIX_EPOCH)
1067                    .map(|d| d.as_millis() as u64)
1068                    .unwrap_or(0);
1069                project_state.record_user_action(UserActionRecord {
1070                    action_type,
1071                    buffer_id: buffer.entity_id(),
1072                    line_number: point.row,
1073                    offset,
1074                    timestamp_epoch_ms,
1075                });
1076            }
1077        }
1078
1079        let events = &mut project_state.events;
1080
1081        let now = cx.background_executor().now();
1082        if let Some(last_event) = project_state.last_event.as_mut() {
1083            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
1084                == last_event.new_snapshot.remote_id()
1085                && old_snapshot.version == last_event.new_snapshot.version;
1086
1087            let should_coalesce = is_next_snapshot_of_same_buffer
1088                && edit_range
1089                    .as_ref()
1090                    .zip(last_event.edit_range.as_ref())
1091                    .is_some_and(|(a, b)| {
1092                        let a = a.to_point(&new_snapshot);
1093                        let b = b.to_point(&new_snapshot);
1094                        if a.start > b.end {
1095                            a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
1096                        } else if b.start > a.end {
1097                            b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
1098                        } else {
1099                            true
1100                        }
1101                    });
1102
1103            if should_coalesce {
1104                let pause_elapsed = last_event
1105                    .last_edit_time
1106                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
1107                    .unwrap_or(false);
1108                if pause_elapsed {
1109                    last_event.snapshot_after_last_editing_pause =
1110                        Some(last_event.new_snapshot.clone());
1111                }
1112
1113                last_event.edit_range = edit_range;
1114                last_event.new_snapshot = new_snapshot;
1115                last_event.last_edit_time = Some(now);
1116                return;
1117            }
1118        }
1119
1120        if let Some(event) = project_state.last_event.take() {
1121            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
1122                if events.len() + 1 >= EVENT_COUNT_MAX {
1123                    events.pop_front();
1124                }
1125                events.push_back(event);
1126            }
1127        }
1128
1129        project_state.last_event = Some(LastEvent {
1130            old_file,
1131            new_file,
1132            old_snapshot,
1133            new_snapshot,
1134            edit_range,
1135            snapshot_after_last_editing_pause: None,
1136            last_edit_time: Some(now),
1137        });
1138    }
1139
1140    fn prediction_at(
1141        &mut self,
1142        buffer: &Entity<Buffer>,
1143        position: Option<language::Anchor>,
1144        project: &Entity<Project>,
1145        cx: &App,
1146    ) -> Option<BufferEditPrediction<'_>> {
1147        let project_state = self.projects.get_mut(&project.entity_id())?;
1148        if let Some(position) = position
1149            && let Some(buffer) = project_state
1150                .registered_buffers
1151                .get_mut(&buffer.entity_id())
1152        {
1153            buffer.last_position = Some(position);
1154        }
1155
1156        let CurrentEditPrediction {
1157            requested_by,
1158            prediction,
1159            ..
1160        } = project_state.current_prediction.as_ref()?;
1161
1162        if prediction.targets_buffer(buffer.read(cx)) {
1163            Some(BufferEditPrediction::Local { prediction })
1164        } else {
1165            let show_jump = match requested_by {
1166                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1167                    requested_by_buffer_id == &buffer.entity_id()
1168                }
1169                PredictionRequestedBy::DiagnosticsUpdate => true,
1170            };
1171
1172            if show_jump {
1173                Some(BufferEditPrediction::Jump { prediction })
1174            } else {
1175                None
1176            }
1177        }
1178    }
1179
1180    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1181        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1182            return;
1183        };
1184
1185        let Some(current_prediction) = project_state.current_prediction.take() else {
1186            return;
1187        };
1188
1189        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1190            project_state.cancel_pending_prediction(pending_prediction, cx);
1191        }
1192
1193        match self.edit_prediction_model {
1194            EditPredictionModel::Sweep => {
1195                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1196            }
1197            EditPredictionModel::Mercury => {}
1198            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1199                zeta2::edit_prediction_accepted(self, current_prediction, cx)
1200            }
1201        }
1202    }
1203
1204    async fn handle_rejected_predictions(
1205        rx: UnboundedReceiver<EditPredictionRejection>,
1206        client: Arc<Client>,
1207        llm_token: LlmApiToken,
1208        app_version: Version,
1209        background_executor: BackgroundExecutor,
1210    ) {
1211        let mut rx = std::pin::pin!(rx.peekable());
1212        let mut batched = Vec::new();
1213
1214        while let Some(rejection) = rx.next().await {
1215            batched.push(rejection);
1216
1217            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1218                select_biased! {
1219                    next = rx.as_mut().peek().fuse() => {
1220                        if next.is_some() {
1221                            continue;
1222                        }
1223                    }
1224                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1225                }
1226            }
1227
1228            let url = client
1229                .http_client()
1230                .build_zed_llm_url("/predict_edits/reject", &[])
1231                .unwrap();
1232
1233            let flush_count = batched
1234                .len()
1235                // in case items have accumulated after failure
1236                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1237            let start = batched.len() - flush_count;
1238
1239            let body = RejectEditPredictionsBodyRef {
1240                rejections: &batched[start..],
1241            };
1242
1243            let result = Self::send_api_request::<()>(
1244                |builder| {
1245                    let req = builder
1246                        .uri(url.as_ref())
1247                        .body(serde_json::to_string(&body)?.into());
1248                    anyhow::Ok(req?)
1249                },
1250                client.clone(),
1251                llm_token.clone(),
1252                app_version.clone(),
1253                true,
1254            )
1255            .await;
1256
1257            if result.log_err().is_some() {
1258                batched.drain(start..);
1259            }
1260        }
1261    }
1262
1263    fn reject_current_prediction(
1264        &mut self,
1265        reason: EditPredictionRejectReason,
1266        project: &Entity<Project>,
1267    ) {
1268        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1269            project_state.pending_predictions.clear();
1270            if let Some(prediction) = project_state.current_prediction.take() {
1271                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1272            }
1273        };
1274    }
1275
1276    fn did_show_current_prediction(
1277        &mut self,
1278        project: &Entity<Project>,
1279        display_type: edit_prediction_types::SuggestionDisplayType,
1280        cx: &mut Context<Self>,
1281    ) {
1282        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1283            return;
1284        };
1285
1286        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1287            return;
1288        };
1289
1290        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1291        let previous_shown_with = current_prediction.shown_with;
1292
1293        if previous_shown_with.is_none() || !is_jump {
1294            current_prediction.shown_with = Some(display_type);
1295        }
1296
1297        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1298
1299        if is_first_non_jump_show {
1300            current_prediction.was_shown = true;
1301        }
1302
1303        let display_type_changed = previous_shown_with != Some(display_type);
1304
1305        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1306            sweep_ai::edit_prediction_shown(
1307                &self.sweep_ai,
1308                self.client.clone(),
1309                &current_prediction.prediction,
1310                display_type,
1311                cx,
1312            );
1313        }
1314
1315        if is_first_non_jump_show {
1316            self.shown_predictions
1317                .push_front(current_prediction.prediction.clone());
1318            if self.shown_predictions.len() > 50 {
1319                let completion = self.shown_predictions.pop_back().unwrap();
1320                self.rated_predictions.remove(&completion.id);
1321            }
1322        }
1323    }
1324
1325    fn reject_prediction(
1326        &mut self,
1327        prediction_id: EditPredictionId,
1328        reason: EditPredictionRejectReason,
1329        was_shown: bool,
1330    ) {
1331        match self.edit_prediction_model {
1332            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1333                if self.custom_predict_edits_url.is_some() {
1334                    return;
1335                }
1336            }
1337            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1338        }
1339
1340        self.reject_predictions_tx
1341            .unbounded_send(EditPredictionRejection {
1342                request_id: prediction_id.to_string(),
1343                reason,
1344                was_shown,
1345            })
1346            .log_err();
1347    }
1348
1349    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1350        self.projects
1351            .get(&project.entity_id())
1352            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1353    }
1354
1355    pub fn refresh_prediction_from_buffer(
1356        &mut self,
1357        project: Entity<Project>,
1358        buffer: Entity<Buffer>,
1359        position: language::Anchor,
1360        cx: &mut Context<Self>,
1361    ) {
1362        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1363            let Some(request_task) = this
1364                .update(cx, |this, cx| {
1365                    this.request_prediction(
1366                        &project,
1367                        &buffer,
1368                        position,
1369                        PredictEditsRequestTrigger::Other,
1370                        cx,
1371                    )
1372                })
1373                .log_err()
1374            else {
1375                return Task::ready(anyhow::Ok(None));
1376            };
1377
1378            cx.spawn(async move |_cx| {
1379                request_task.await.map(|prediction_result| {
1380                    prediction_result.map(|prediction_result| {
1381                        (
1382                            prediction_result,
1383                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1384                        )
1385                    })
1386                })
1387            })
1388        })
1389    }
1390
1391    pub fn refresh_prediction_from_diagnostics(
1392        &mut self,
1393        project: Entity<Project>,
1394        cx: &mut Context<Self>,
1395    ) {
1396        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1397            return;
1398        };
1399
1400        // Prefer predictions from buffer
1401        if project_state.current_prediction.is_some() {
1402            return;
1403        };
1404
1405        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1406            let Some((active_buffer, snapshot, cursor_point)) = this
1407                .read_with(cx, |this, cx| {
1408                    let project_state = this.projects.get(&project.entity_id())?;
1409                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
1410                    let snapshot = buffer.read(cx).snapshot();
1411
1412                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
1413                        return None;
1414                    }
1415
1416                    let cursor_point = position
1417                        .map(|pos| pos.to_point(&snapshot))
1418                        .unwrap_or_default();
1419
1420                    Some((buffer, snapshot, cursor_point))
1421                })
1422                .log_err()
1423                .flatten()
1424            else {
1425                return Task::ready(anyhow::Ok(None));
1426            };
1427
1428            cx.spawn(async move |cx| {
1429                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1430                    active_buffer,
1431                    &snapshot,
1432                    Default::default(),
1433                    cursor_point,
1434                    &project,
1435                    cx,
1436                )
1437                .await?
1438                else {
1439                    return anyhow::Ok(None);
1440                };
1441
1442                let Some(prediction_result) = this
1443                    .update(cx, |this, cx| {
1444                        this.request_prediction(
1445                            &project,
1446                            &jump_buffer,
1447                            jump_position,
1448                            PredictEditsRequestTrigger::Diagnostics,
1449                            cx,
1450                        )
1451                    })?
1452                    .await?
1453                else {
1454                    return anyhow::Ok(None);
1455                };
1456
1457                this.update(cx, |this, cx| {
1458                    Some((
1459                        if this
1460                            .get_or_init_project(&project, cx)
1461                            .current_prediction
1462                            .is_none()
1463                        {
1464                            prediction_result
1465                        } else {
1466                            EditPredictionResult {
1467                                id: prediction_result.id,
1468                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1469                            }
1470                        },
1471                        PredictionRequestedBy::DiagnosticsUpdate,
1472                    ))
1473                })
1474            })
1475        });
1476    }
1477
1478    fn predictions_enabled_at(
1479        snapshot: &BufferSnapshot,
1480        position: Option<language::Anchor>,
1481        cx: &App,
1482    ) -> bool {
1483        let file = snapshot.file();
1484        let all_settings = all_language_settings(file, cx);
1485        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1486            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1487        {
1488            return false;
1489        }
1490
1491        if let Some(last_position) = position {
1492            let settings = snapshot.settings_at(last_position, cx);
1493
1494            if !settings.edit_predictions_disabled_in.is_empty()
1495                && let Some(scope) = snapshot.language_scope_at(last_position)
1496                && let Some(scope_name) = scope.override_name()
1497                && settings
1498                    .edit_predictions_disabled_in
1499                    .iter()
1500                    .any(|s| s == scope_name)
1501            {
1502                return false;
1503            }
1504        }
1505
1506        true
1507    }
1508
1509    #[cfg(not(test))]
1510    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1511    #[cfg(test)]
1512    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1513
1514    fn queue_prediction_refresh(
1515        &mut self,
1516        project: Entity<Project>,
1517        throttle_entity: EntityId,
1518        cx: &mut Context<Self>,
1519        do_refresh: impl FnOnce(
1520            WeakEntity<Self>,
1521            &mut AsyncApp,
1522        )
1523            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1524        + 'static,
1525    ) {
1526        let project_state = self.get_or_init_project(&project, cx);
1527        let pending_prediction_id = project_state.next_pending_prediction_id;
1528        project_state.next_pending_prediction_id += 1;
1529        let last_request = project_state.last_prediction_refresh;
1530
1531        let task = cx.spawn(async move |this, cx| {
1532            if let Some((last_entity, last_timestamp)) = last_request
1533                && throttle_entity == last_entity
1534                && let Some(timeout) =
1535                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1536            {
1537                cx.background_executor().timer(timeout).await;
1538            }
1539
1540            // If this task was cancelled before the throttle timeout expired,
1541            // do not perform a request.
1542            let mut is_cancelled = true;
1543            this.update(cx, |this, cx| {
1544                let project_state = this.get_or_init_project(&project, cx);
1545                if !project_state
1546                    .cancelled_predictions
1547                    .remove(&pending_prediction_id)
1548                {
1549                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1550                    is_cancelled = false;
1551                }
1552            })
1553            .ok();
1554            if is_cancelled {
1555                return None;
1556            }
1557
1558            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1559            let new_prediction_id = new_prediction_result
1560                .as_ref()
1561                .map(|(prediction, _)| prediction.id.clone());
1562
1563            // When a prediction completes, remove it from the pending list, and cancel
1564            // any pending predictions that were enqueued before it.
1565            this.update(cx, |this, cx| {
1566                let project_state = this.get_or_init_project(&project, cx);
1567
1568                let is_cancelled = project_state
1569                    .cancelled_predictions
1570                    .remove(&pending_prediction_id);
1571
1572                let new_current_prediction = if !is_cancelled
1573                    && let Some((prediction_result, requested_by)) = new_prediction_result
1574                {
1575                    match prediction_result.prediction {
1576                        Ok(prediction) => {
1577                            let new_prediction = CurrentEditPrediction {
1578                                requested_by,
1579                                prediction,
1580                                was_shown: false,
1581                                shown_with: None,
1582                            };
1583
1584                            if let Some(current_prediction) =
1585                                project_state.current_prediction.as_ref()
1586                            {
1587                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1588                                {
1589                                    this.reject_current_prediction(
1590                                        EditPredictionRejectReason::Replaced,
1591                                        &project,
1592                                    );
1593
1594                                    Some(new_prediction)
1595                                } else {
1596                                    this.reject_prediction(
1597                                        new_prediction.prediction.id,
1598                                        EditPredictionRejectReason::CurrentPreferred,
1599                                        false,
1600                                    );
1601                                    None
1602                                }
1603                            } else {
1604                                Some(new_prediction)
1605                            }
1606                        }
1607                        Err(reject_reason) => {
1608                            this.reject_prediction(prediction_result.id, reject_reason, false);
1609                            None
1610                        }
1611                    }
1612                } else {
1613                    None
1614                };
1615
1616                let project_state = this.get_or_init_project(&project, cx);
1617
1618                if let Some(new_prediction) = new_current_prediction {
1619                    project_state.current_prediction = Some(new_prediction);
1620                }
1621
1622                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1623                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1624                    if pending_prediction.id == pending_prediction_id {
1625                        pending_predictions.remove(ix);
1626                        for pending_prediction in pending_predictions.drain(0..ix) {
1627                            project_state.cancel_pending_prediction(pending_prediction, cx)
1628                        }
1629                        break;
1630                    }
1631                }
1632                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1633                cx.notify();
1634            })
1635            .ok();
1636
1637            new_prediction_id
1638        });
1639
1640        if project_state.pending_predictions.len() <= 1 {
1641            project_state.pending_predictions.push(PendingPrediction {
1642                id: pending_prediction_id,
1643                task,
1644            });
1645        } else if project_state.pending_predictions.len() == 2 {
1646            let pending_prediction = project_state.pending_predictions.pop().unwrap();
1647            project_state.pending_predictions.push(PendingPrediction {
1648                id: pending_prediction_id,
1649                task,
1650            });
1651            project_state.cancel_pending_prediction(pending_prediction, cx);
1652        }
1653    }
1654
1655    pub fn request_prediction(
1656        &mut self,
1657        project: &Entity<Project>,
1658        active_buffer: &Entity<Buffer>,
1659        position: language::Anchor,
1660        trigger: PredictEditsRequestTrigger,
1661        cx: &mut Context<Self>,
1662    ) -> Task<Result<Option<EditPredictionResult>>> {
1663        self.request_prediction_internal(
1664            project.clone(),
1665            active_buffer.clone(),
1666            position,
1667            trigger,
1668            cx.has_flag::<Zeta2FeatureFlag>(),
1669            cx,
1670        )
1671    }
1672
1673    fn request_prediction_internal(
1674        &mut self,
1675        project: Entity<Project>,
1676        active_buffer: Entity<Buffer>,
1677        position: language::Anchor,
1678        trigger: PredictEditsRequestTrigger,
1679        allow_jump: bool,
1680        cx: &mut Context<Self>,
1681    ) -> Task<Result<Option<EditPredictionResult>>> {
1682        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1683
1684        self.get_or_init_project(&project, cx);
1685        let project_state = self.projects.get(&project.entity_id()).unwrap();
1686        let stored_events = project_state.events(cx);
1687        let has_events = !stored_events.is_empty();
1688        let events: Vec<Arc<zeta_prompt::Event>> =
1689            stored_events.into_iter().map(|e| e.event).collect();
1690        let debug_tx = project_state.debug_tx.clone();
1691
1692        let snapshot = active_buffer.read(cx).snapshot();
1693        let cursor_point = position.to_point(&snapshot);
1694        let current_offset = position.to_offset(&snapshot);
1695
1696        let mut user_actions: Vec<UserActionRecord> =
1697            project_state.user_actions.iter().cloned().collect();
1698
1699        if let Some(last_action) = user_actions.last() {
1700            if last_action.buffer_id == active_buffer.entity_id()
1701                && current_offset != last_action.offset
1702            {
1703                let timestamp_epoch_ms = SystemTime::now()
1704                    .duration_since(UNIX_EPOCH)
1705                    .map(|d| d.as_millis() as u64)
1706                    .unwrap_or(0);
1707                user_actions.push(UserActionRecord {
1708                    action_type: UserActionType::CursorMovement,
1709                    buffer_id: active_buffer.entity_id(),
1710                    line_number: cursor_point.row,
1711                    offset: current_offset,
1712                    timestamp_epoch_ms,
1713                });
1714            }
1715        }
1716        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1717        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1718        let diagnostic_search_range =
1719            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1720
1721        let related_files = if self.use_context {
1722            self.context_for_project(&project, cx)
1723        } else {
1724            Vec::new()
1725        };
1726
1727        let inputs = EditPredictionModelInput {
1728            project: project.clone(),
1729            buffer: active_buffer.clone(),
1730            snapshot: snapshot.clone(),
1731            position,
1732            events,
1733            related_files,
1734            recent_paths: project_state.recent_paths.clone(),
1735            trigger,
1736            diagnostic_search_range: diagnostic_search_range.clone(),
1737            debug_tx,
1738            user_actions,
1739        };
1740
1741        let can_collect_example = snapshot
1742            .file()
1743            .is_some_and(|file| self.can_collect_file(&project, file, cx))
1744            && self.can_collect_events(&inputs.events, cx)
1745            && self.can_collect_related_files(&project, cx);
1746
1747        if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
1748            let events_for_capture =
1749                self.edit_history_for_project_with_pause_split_last_event(&project, cx);
1750            let related_files_for_capture = inputs.related_files.clone();
1751            if let Some(example_task) = capture_example::capture_example(
1752                project.clone(),
1753                active_buffer.clone(),
1754                position,
1755                events_for_capture,
1756                related_files_for_capture,
1757                false,
1758                cx,
1759            ) {
1760                cx.spawn(async move |_this, _cx| {
1761                    let example = example_task.await?;
1762                    telemetry::event!("Edit Prediction Example Captured", example = example);
1763                    anyhow::Ok(())
1764                })
1765                .detach_and_log_err(cx);
1766            }
1767        }
1768        let task = match self.edit_prediction_model {
1769            EditPredictionModel::Zeta1 => {
1770                if should_send_testing_zeta2_request() {
1771                    let mut zeta2_inputs = inputs.clone();
1772                    zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing;
1773                    zeta2::request_prediction_with_zeta2(
1774                        self,
1775                        zeta2_inputs,
1776                        Default::default(),
1777                        cx,
1778                    )
1779                    .detach();
1780                }
1781                zeta1::request_prediction_with_zeta1(self, inputs, cx)
1782            }
1783            EditPredictionModel::Zeta2 { version } => {
1784                zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
1785            }
1786            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1787            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1788        };
1789
1790        cx.spawn(async move |this, cx| {
1791            let prediction = task.await?;
1792
1793            if prediction.is_none() && allow_jump {
1794                let cursor_point = position.to_point(&snapshot);
1795                if has_events
1796                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1797                        active_buffer.clone(),
1798                        &snapshot,
1799                        diagnostic_search_range,
1800                        cursor_point,
1801                        &project,
1802                        cx,
1803                    )
1804                    .await?
1805                {
1806                    return this
1807                        .update(cx, |this, cx| {
1808                            this.request_prediction_internal(
1809                                project,
1810                                jump_buffer,
1811                                jump_position,
1812                                trigger,
1813                                false,
1814                                cx,
1815                            )
1816                        })?
1817                        .await;
1818                }
1819
1820                return anyhow::Ok(None);
1821            }
1822
1823            Ok(prediction)
1824        })
1825    }
1826
1827    async fn next_diagnostic_location(
1828        active_buffer: Entity<Buffer>,
1829        active_buffer_snapshot: &BufferSnapshot,
1830        active_buffer_diagnostic_search_range: Range<Point>,
1831        active_buffer_cursor_point: Point,
1832        project: &Entity<Project>,
1833        cx: &mut AsyncApp,
1834    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1835        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1836        let mut jump_location = active_buffer_snapshot
1837            .diagnostic_groups(None)
1838            .into_iter()
1839            .filter_map(|(_, group)| {
1840                let range = &group.entries[group.primary_ix]
1841                    .range
1842                    .to_point(&active_buffer_snapshot);
1843                if range.overlaps(&active_buffer_diagnostic_search_range) {
1844                    None
1845                } else {
1846                    Some(range.start)
1847                }
1848            })
1849            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1850            .map(|position| {
1851                (
1852                    active_buffer.clone(),
1853                    active_buffer_snapshot.anchor_before(position),
1854                )
1855            });
1856
1857        if jump_location.is_none() {
1858            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1859                let file = buffer.file()?;
1860
1861                Some(ProjectPath {
1862                    worktree_id: file.worktree_id(cx),
1863                    path: file.path().clone(),
1864                })
1865            });
1866
1867            let buffer_task = project.update(cx, |project, cx| {
1868                let (path, _, _) = project
1869                    .diagnostic_summaries(false, cx)
1870                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1871                    .max_by_key(|(path, _, _)| {
1872                        // find the buffer with errors that shares most parent directories
1873                        path.path
1874                            .components()
1875                            .zip(
1876                                active_buffer_path
1877                                    .as_ref()
1878                                    .map(|p| p.path.components())
1879                                    .unwrap_or_default(),
1880                            )
1881                            .take_while(|(a, b)| a == b)
1882                            .count()
1883                    })?;
1884
1885                Some(project.open_buffer(path, cx))
1886            });
1887
1888            if let Some(buffer_task) = buffer_task {
1889                let closest_buffer = buffer_task.await?;
1890
1891                jump_location = closest_buffer
1892                    .read_with(cx, |buffer, _cx| {
1893                        buffer
1894                            .buffer_diagnostics(None)
1895                            .into_iter()
1896                            .min_by_key(|entry| entry.diagnostic.severity)
1897                            .map(|entry| entry.range.start)
1898                    })
1899                    .map(|position| (closest_buffer, position));
1900            }
1901        }
1902
1903        anyhow::Ok(jump_location)
1904    }
1905
1906    async fn send_raw_llm_request(
1907        request: RawCompletionRequest,
1908        client: Arc<Client>,
1909        custom_url: Option<Arc<Url>>,
1910        llm_token: LlmApiToken,
1911        app_version: Version,
1912    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1913        let url = if let Some(custom_url) = custom_url {
1914            custom_url.as_ref().clone()
1915        } else {
1916            client
1917                .http_client()
1918                .build_zed_llm_url("/predict_edits/raw", &[])?
1919        };
1920
1921        Self::send_api_request(
1922            |builder| {
1923                let req = builder
1924                    .uri(url.as_ref())
1925                    .body(serde_json::to_string(&request)?.into());
1926                Ok(req?)
1927            },
1928            client,
1929            llm_token,
1930            app_version,
1931            true,
1932        )
1933        .await
1934    }
1935
1936    pub(crate) async fn send_v3_request(
1937        input: ZetaPromptInput,
1938        prompt_version: ZetaVersion,
1939        client: Arc<Client>,
1940        llm_token: LlmApiToken,
1941        app_version: Version,
1942        trigger: PredictEditsRequestTrigger,
1943    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
1944        let url = client
1945            .http_client()
1946            .build_zed_llm_url("/predict_edits/v3", &[])?;
1947
1948        let request = PredictEditsV3Request {
1949            input,
1950            model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1951            prompt_version,
1952            trigger,
1953        };
1954
1955        Self::send_api_request(
1956            |builder| {
1957                let req = builder
1958                    .uri(url.as_ref())
1959                    .body(serde_json::to_string(&request)?.into());
1960                Ok(req?)
1961            },
1962            client,
1963            llm_token,
1964            app_version,
1965            true,
1966        )
1967        .await
1968    }
1969
1970    fn handle_api_response<T>(
1971        this: &WeakEntity<Self>,
1972        response: Result<(T, Option<EditPredictionUsage>)>,
1973        cx: &mut gpui::AsyncApp,
1974    ) -> Result<T> {
1975        match response {
1976            Ok((data, usage)) => {
1977                if let Some(usage) = usage {
1978                    this.update(cx, |this, cx| {
1979                        this.user_store.update(cx, |user_store, cx| {
1980                            user_store.update_edit_prediction_usage(usage, cx);
1981                        });
1982                    })
1983                    .ok();
1984                }
1985                Ok(data)
1986            }
1987            Err(err) => {
1988                if err.is::<ZedUpdateRequiredError>() {
1989                    cx.update(|cx| {
1990                        this.update(cx, |this, _cx| {
1991                            this.update_required = true;
1992                        })
1993                        .ok();
1994
1995                        let error_message: SharedString = err.to_string().into();
1996                        show_app_notification(
1997                            NotificationId::unique::<ZedUpdateRequiredError>(),
1998                            cx,
1999                            move |cx| {
2000                                cx.new(|cx| {
2001                                    ErrorMessagePrompt::new(error_message.clone(), cx)
2002                                        .with_link_button("Update Zed", "https://zed.dev/releases")
2003                                })
2004                            },
2005                        );
2006                    });
2007                }
2008                Err(err)
2009            }
2010        }
2011    }
2012
2013    async fn send_api_request<Res>(
2014        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2015        client: Arc<Client>,
2016        llm_token: LlmApiToken,
2017        app_version: Version,
2018        require_auth: bool,
2019    ) -> Result<(Res, Option<EditPredictionUsage>)>
2020    where
2021        Res: DeserializeOwned,
2022    {
2023        let http_client = client.http_client();
2024
2025        let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
2026            Some(custom_token)
2027        } else if require_auth {
2028            Some(llm_token.acquire(&client).await?)
2029        } else {
2030            llm_token.acquire(&client).await.ok()
2031        };
2032        let mut did_retry = false;
2033
2034        loop {
2035            let request_builder = http_client::Request::builder().method(Method::POST);
2036
2037            let mut request_builder = request_builder
2038                .header("Content-Type", "application/json")
2039                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2040
2041            // Only add Authorization header if we have a token
2042            if let Some(ref token_value) = token {
2043                request_builder =
2044                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2045            }
2046
2047            let request = build(request_builder)?;
2048
2049            let mut response = http_client.send(request).await?;
2050
2051            if let Some(minimum_required_version) = response
2052                .headers()
2053                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2054                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2055            {
2056                anyhow::ensure!(
2057                    app_version >= minimum_required_version,
2058                    ZedUpdateRequiredError {
2059                        minimum_version: minimum_required_version
2060                    }
2061                );
2062            }
2063
2064            if response.status().is_success() {
2065                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2066
2067                let mut body = Vec::new();
2068                response.body_mut().read_to_end(&mut body).await?;
2069                return Ok((serde_json::from_slice(&body)?, usage));
2070            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2071                did_retry = true;
2072                token = Some(llm_token.refresh(&client).await?);
2073            } else {
2074                let mut body = String::new();
2075                response.body_mut().read_to_string(&mut body).await?;
2076                anyhow::bail!(
2077                    "Request failed with status: {:?}\nBody: {}",
2078                    response.status(),
2079                    body
2080                );
2081            }
2082        }
2083    }
2084
2085    pub fn refresh_context(
2086        &mut self,
2087        project: &Entity<Project>,
2088        buffer: &Entity<language::Buffer>,
2089        cursor_position: language::Anchor,
2090        cx: &mut Context<Self>,
2091    ) {
2092        if self.use_context {
2093            self.get_or_init_project(project, cx)
2094                .context
2095                .update(cx, |store, cx| {
2096                    store.refresh(buffer.clone(), cursor_position, cx);
2097                });
2098        }
2099    }
2100
2101    #[cfg(feature = "cli-support")]
2102    pub fn set_context_for_buffer(
2103        &mut self,
2104        project: &Entity<Project>,
2105        related_files: Vec<RelatedFile>,
2106        cx: &mut Context<Self>,
2107    ) {
2108        self.get_or_init_project(project, cx)
2109            .context
2110            .update(cx, |store, cx| {
2111                store.set_related_files(related_files, cx);
2112            });
2113    }
2114
2115    #[cfg(feature = "cli-support")]
2116    pub fn set_recent_paths_for_project(
2117        &mut self,
2118        project: &Entity<Project>,
2119        paths: impl IntoIterator<Item = project::ProjectPath>,
2120        cx: &mut Context<Self>,
2121    ) {
2122        let project_state = self.get_or_init_project(project, cx);
2123        project_state.recent_paths = paths.into_iter().collect();
2124    }
2125
2126    fn is_file_open_source(
2127        &self,
2128        project: &Entity<Project>,
2129        file: &Arc<dyn File>,
2130        cx: &App,
2131    ) -> bool {
2132        if !file.is_local() || file.is_private() {
2133            return false;
2134        }
2135        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2136            return false;
2137        };
2138        project_state
2139            .license_detection_watchers
2140            .get(&file.worktree_id(cx))
2141            .as_ref()
2142            .is_some_and(|watcher| watcher.is_project_open_source())
2143    }
2144
2145    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2146        self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2147    }
2148
2149    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2150        if !self.data_collection_choice.is_enabled(cx) {
2151            return false;
2152        }
2153        events.iter().all(|event| {
2154            matches!(
2155                event.as_ref(),
2156                zeta_prompt::Event::BufferChange {
2157                    in_open_source_repo: true,
2158                    ..
2159                }
2160            )
2161        })
2162    }
2163
2164    fn can_collect_related_files(&self, project: &Entity<Project>, cx: &mut App) -> bool {
2165        if !self.data_collection_choice.is_enabled(cx) {
2166            return false;
2167        }
2168
2169        let related_with_buffers = self.context_for_project_with_buffers(project, cx);
2170
2171        related_with_buffers.iter().all(|(_, buffer)| {
2172            buffer
2173                .read(cx)
2174                .file()
2175                .is_some_and(|file| self.is_file_open_source(project, &file, cx))
2176        })
2177    }
2178
2179    fn load_data_collection_choice() -> DataCollectionChoice {
2180        let choice = KEY_VALUE_STORE
2181            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2182            .log_err()
2183            .flatten();
2184
2185        match choice.as_deref() {
2186            Some("true") => DataCollectionChoice::Enabled,
2187            Some("false") => DataCollectionChoice::Disabled,
2188            Some(_) => {
2189                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2190                DataCollectionChoice::NotAnswered
2191            }
2192            None => DataCollectionChoice::NotAnswered,
2193        }
2194    }
2195
2196    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2197        self.data_collection_choice = self.data_collection_choice.toggle();
2198        let new_choice = self.data_collection_choice;
2199        let is_enabled = new_choice.is_enabled(cx);
2200        db::write_and_log(cx, move || {
2201            KEY_VALUE_STORE.write_kvp(
2202                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2203                is_enabled.to_string(),
2204            )
2205        });
2206    }
2207
2208    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2209        self.shown_predictions.iter()
2210    }
2211
2212    pub fn shown_completions_len(&self) -> usize {
2213        self.shown_predictions.len()
2214    }
2215
2216    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2217        self.rated_predictions.contains(id)
2218    }
2219
2220    pub fn rate_prediction(
2221        &mut self,
2222        prediction: &EditPrediction,
2223        rating: EditPredictionRating,
2224        feedback: String,
2225        cx: &mut Context<Self>,
2226    ) {
2227        self.rated_predictions.insert(prediction.id.clone());
2228        telemetry::event!(
2229            "Edit Prediction Rated",
2230            rating,
2231            inputs = prediction.inputs,
2232            output = prediction
2233                .edit_preview
2234                .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2235            feedback
2236        );
2237        self.client.telemetry().flush_events().detach();
2238        cx.notify();
2239    }
2240
2241    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
2242        if cfg!(feature = "cli-support") {
2243            return;
2244        }
2245        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
2246            && all_language_settings(None, cx).edit_predictions.use_context;
2247    }
2248}
2249
2250pub(crate) fn filter_redundant_excerpts(
2251    mut related_files: Vec<RelatedFile>,
2252    cursor_path: &Path,
2253    cursor_row_range: Range<u32>,
2254) -> Vec<RelatedFile> {
2255    for file in &mut related_files {
2256        if file.path.as_ref() == cursor_path {
2257            file.excerpts.retain(|excerpt| {
2258                excerpt.row_range.start < cursor_row_range.start
2259                    || excerpt.row_range.end > cursor_row_range.end
2260            });
2261        }
2262    }
2263    related_files.retain(|file| !file.excerpts.is_empty());
2264    related_files
2265}
2266
2267#[derive(Error, Debug)]
2268#[error(
2269    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2270)]
2271pub struct ZedUpdateRequiredError {
2272    minimum_version: Version,
2273}
2274
/// The user's answer to the edit prediction data collection prompt.
#[derive(Clone, Copy, Debug)]
pub enum DataCollectionChoice {
    /// No answer has been stored yet.
    NotAnswered,
    /// Data collection was enabled.
    Enabled,
    /// Data collection was disabled.
    Disabled,
}
2281
2282impl DataCollectionChoice {
2283    pub fn is_enabled(self, cx: &App) -> bool {
2284        if cx.is_staff() {
2285            return true;
2286        }
2287        match self {
2288            Self::Enabled => true,
2289            Self::NotAnswered | Self::Disabled => false,
2290        }
2291    }
2292
2293    #[must_use]
2294    pub fn toggle(&self) -> DataCollectionChoice {
2295        match self {
2296            Self::Enabled => Self::Disabled,
2297            Self::Disabled => Self::Enabled,
2298            Self::NotAnswered => Self::Enabled,
2299        }
2300    }
2301}
2302
2303impl From<bool> for DataCollectionChoice {
2304    fn from(value: bool) -> Self {
2305        match value {
2306            true => DataCollectionChoice::Enabled,
2307            false => DataCollectionChoice::Disabled,
2308        }
2309    }
2310}
2311
2312struct ZedPredictUpsell;
2313
2314impl Dismissable for ZedPredictUpsell {
2315    const KEY: &'static str = "dismissed-edit-predict-upsell";
2316
2317    fn dismissed() -> bool {
2318        // To make this backwards compatible with older versions of Zed, we
2319        // check if the user has seen the previous Edit Prediction Onboarding
2320        // before, by checking the data collection choice which was written to
2321        // the database once the user clicked on "Accept and Enable"
2322        if KEY_VALUE_STORE
2323            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2324            .log_err()
2325            .is_some_and(|s| s.is_some())
2326        {
2327            return true;
2328        }
2329
2330        KEY_VALUE_STORE
2331            .read_kvp(Self::KEY)
2332            .log_err()
2333            .is_some_and(|s| s.is_some())
2334    }
2335}
2336
2337pub fn should_show_upsell_modal() -> bool {
2338    !ZedPredictUpsell::dismissed()
2339}
2340
2341pub fn init(cx: &mut App) {
2342    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2343        workspace.register_action(
2344            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2345                ZedPredictModal::toggle(
2346                    workspace,
2347                    workspace.user_store().clone(),
2348                    workspace.client().clone(),
2349                    window,
2350                    cx,
2351                )
2352            },
2353        );
2354
2355        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2356            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2357                settings
2358                    .project
2359                    .all_languages
2360                    .features
2361                    .get_or_insert_default()
2362                    .edit_prediction_provider = Some(EditPredictionProvider::None)
2363            });
2364        });
2365        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2366            EditPredictionStore::try_global(cx).and_then(|store| {
2367                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2368            })
2369        }
2370
2371        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2372            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2373                copilot_ui::initiate_sign_in(copilot, window, cx);
2374            }
2375        });
2376        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2377            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2378                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2379            }
2380        });
2381        workspace.register_action(|workspace, _: &SignOut, window, cx| {
2382            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2383                copilot_ui::initiate_sign_out(copilot, window, cx);
2384            }
2385        });
2386    })
2387    .detach();
2388}