edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{
   5    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   6};
   7use cloud_llm_client::{
   8    EditPredictionRejectReason, EditPredictionRejection,
   9    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  10    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  11};
  12use collections::{HashMap, HashSet};
  13use copilot::{Copilot, Reinstall, SignIn, SignOut};
  14use db::kvp::{Dismissable, KEY_VALUE_STORE};
  15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::{
  18    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  19    channel::mpsc::{self, UnboundedReceiver},
  20    select_biased,
  21};
  22use gpui::BackgroundExecutor;
  23use gpui::http_client::Url;
  24use gpui::{
  25    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  26    http_client::{self, AsyncBody, Method},
  27    prelude::*,
  28};
  29use language::language_settings::all_language_settings;
  30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  31use language::{BufferSnapshot, OffsetRangeExt};
  32use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  33use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  34use release_channel::AppVersion;
  35use semver::Version;
  36use serde::de::DeserializeOwned;
  37use settings::{EditPredictionProvider, Settings as _, update_settings_file};
  38use std::collections::{VecDeque, hash_map};
  39use text::Edit;
  40use workspace::Workspace;
  41use zeta_prompt::ZetaPromptInput;
  42use zeta_prompt::ZetaVersion;
  43
  44use std::ops::Range;
  45use std::path::Path;
  46use std::rc::Rc;
  47use std::str::FromStr as _;
  48use std::sync::{Arc, LazyLock};
  49use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  50use std::{env, mem};
  51use thiserror::Error;
  52use util::{RangeExt as _, ResultExt as _};
  53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  54
  55pub mod cursor_excerpt;
  56pub mod example_spec;
  57mod license_detection;
  58pub mod mercury;
  59pub mod ollama;
  60mod onboarding_modal;
  61pub mod open_ai_response;
  62mod prediction;
  63pub mod sweep_ai;
  64
  65pub mod udiff;
  66
  67mod capture_example;
  68mod zed_edit_prediction_delegate;
  69pub mod zeta1;
  70pub mod zeta2;
  71
  72#[cfg(test)]
  73mod edit_prediction_tests;
  74
  75use crate::capture_example::{
  76    should_sample_edit_prediction_example_capture, should_send_testing_zeta2_request,
  77};
  78use crate::license_detection::LicenseDetectionWatcher;
  79use crate::mercury::Mercury;
  80use crate::ollama::Ollama;
  81use crate::onboarding_modal::ZedPredictModal;
  82pub use crate::prediction::EditPrediction;
  83pub use crate::prediction::EditPredictionId;
  84use crate::prediction::EditPredictionResult;
  85pub use crate::sweep_ai::SweepAi;
  86pub use capture_example::capture_example;
  87pub use language_model::ApiKeyState;
  88pub use telemetry_events::EditPredictionRating;
  89pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  90
  91actions!(
  92    edit_prediction,
  93    [
  94        /// Resets the edit prediction onboarding state.
  95        ResetOnboarding,
  96        /// Clears the edit prediction history.
  97        ClearHistory,
  98    ]
  99);
 100
 101/// Maximum number of events to track.
 102const EVENT_COUNT_MAX: usize = 6;
 103const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 104const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
 105const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 106const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 107
 108static EDIT_PREDICTIONS_MODEL_ID: LazyLock<Option<String>> =
 109    LazyLock::new(|| env::var("ZED_ZETA_MODEL").ok());
 110
 111pub struct Zeta2FeatureFlag;
 112
 113impl FeatureFlag for Zeta2FeatureFlag {
 114    const NAME: &'static str = "zeta2";
 115
 116    fn enabled_for_staff() -> bool {
 117        true
 118    }
 119}
 120
 121pub struct EditPredictionExampleCaptureFeatureFlag;
 122
 123impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
 124    const NAME: &'static str = "edit-prediction-example-capture";
 125
 126    fn enabled_for_staff() -> bool {
 127        true
 128    }
 129}
 130
 131#[derive(Clone)]
 132struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
 133
 134impl Global for EditPredictionStoreGlobal {}
 135
 136pub struct EditPredictionStore {
 137    client: Arc<Client>,
 138    user_store: Entity<UserStore>,
 139    llm_token: LlmApiToken,
 140    _llm_token_subscription: Subscription,
 141    projects: HashMap<EntityId, ProjectState>,
 142    update_required: bool,
 143    edit_prediction_model: EditPredictionModel,
 144    pub sweep_ai: SweepAi,
 145    pub mercury: Mercury,
 146    pub ollama: Ollama,
 147    data_collection_choice: DataCollectionChoice,
 148    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
 149    shown_predictions: VecDeque<EditPrediction>,
 150    rated_predictions: HashSet<EditPredictionId>,
 151    custom_predict_edits_url: Option<Arc<Url>>,
 152}
 153
 154#[derive(Copy, Clone, Default, PartialEq, Eq)]
 155pub enum EditPredictionModel {
 156    #[default]
 157    Zeta1,
 158    Zeta2 {
 159        version: ZetaVersion,
 160    },
 161    Sweep,
 162    Mercury,
 163    Ollama,
 164}
 165
 166#[derive(Clone)]
 167pub struct EditPredictionModelInput {
 168    project: Entity<Project>,
 169    buffer: Entity<Buffer>,
 170    snapshot: BufferSnapshot,
 171    position: Anchor,
 172    events: Vec<Arc<zeta_prompt::Event>>,
 173    related_files: Vec<RelatedFile>,
 174    recent_paths: VecDeque<ProjectPath>,
 175    trigger: PredictEditsRequestTrigger,
 176    diagnostic_search_range: Range<Point>,
 177    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 178    pub user_actions: Vec<UserActionRecord>,
 179}
 180
 181#[derive(Debug)]
 182pub enum DebugEvent {
 183    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
 184    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
 185    EditPredictionStarted(EditPredictionStartedDebugEvent),
 186    EditPredictionFinished(EditPredictionFinishedDebugEvent),
 187}
 188
 189#[derive(Debug)]
 190pub struct ContextRetrievalStartedDebugEvent {
 191    pub project_entity_id: EntityId,
 192    pub timestamp: Instant,
 193    pub search_prompt: String,
 194}
 195
 196#[derive(Debug)]
 197pub struct ContextRetrievalFinishedDebugEvent {
 198    pub project_entity_id: EntityId,
 199    pub timestamp: Instant,
 200    pub metadata: Vec<(&'static str, SharedString)>,
 201}
 202
 203#[derive(Debug)]
 204pub struct EditPredictionStartedDebugEvent {
 205    pub buffer: WeakEntity<Buffer>,
 206    pub position: Anchor,
 207    pub prompt: Option<String>,
 208}
 209
 210#[derive(Debug)]
 211pub struct EditPredictionFinishedDebugEvent {
 212    pub buffer: WeakEntity<Buffer>,
 213    pub position: Anchor,
 214    pub model_output: Option<String>,
 215}
 216
 217const USER_ACTION_HISTORY_SIZE: usize = 16;
 218
 219#[derive(Clone, Debug)]
 220pub struct UserActionRecord {
 221    pub action_type: UserActionType,
 222    pub buffer_id: EntityId,
 223    pub line_number: u32,
 224    pub offset: usize,
 225    pub timestamp_epoch_ms: u64,
 226}
 227
 228#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 229pub enum UserActionType {
 230    InsertChar,
 231    InsertSelection,
 232    DeleteChar,
 233    DeleteSelection,
 234    CursorMovement,
 235}
 236
 237/// An event with associated metadata for reconstructing buffer state.
 238#[derive(Clone)]
 239pub struct StoredEvent {
 240    pub event: Arc<zeta_prompt::Event>,
 241    pub old_snapshot: TextBufferSnapshot,
 242}
 243
 244struct ProjectState {
 245    events: VecDeque<StoredEvent>,
 246    last_event: Option<LastEvent>,
 247    recent_paths: VecDeque<ProjectPath>,
 248    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 249    current_prediction: Option<CurrentEditPrediction>,
 250    next_pending_prediction_id: usize,
 251    pending_predictions: ArrayVec<PendingPrediction, 2>,
 252    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 253    last_prediction_refresh: Option<(EntityId, Instant)>,
 254    cancelled_predictions: HashSet<usize>,
 255    context: Entity<RelatedExcerptStore>,
 256    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 257    user_actions: VecDeque<UserActionRecord>,
 258    _subscriptions: [gpui::Subscription; 2],
 259    copilot: Option<Entity<Copilot>>,
 260}
 261
 262impl ProjectState {
 263    fn record_user_action(&mut self, action: UserActionRecord) {
 264        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 265            self.user_actions.pop_front();
 266        }
 267        self.user_actions.push_back(action);
 268    }
 269
 270    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 271        self.events
 272            .iter()
 273            .cloned()
 274            .chain(
 275                self.last_event
 276                    .as_ref()
 277                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 278            )
 279            .collect()
 280    }
 281
 282    pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
 283        self.events
 284            .iter()
 285            .cloned()
 286            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 287                let (one, two) = event.split_by_pause();
 288                let one = one.finalize(&self.license_detection_watchers, cx);
 289                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 290                one.into_iter().chain(two)
 291            }))
 292            .collect()
 293    }
 294
 295    fn cancel_pending_prediction(
 296        &mut self,
 297        pending_prediction: PendingPrediction,
 298        cx: &mut Context<EditPredictionStore>,
 299    ) {
 300        self.cancelled_predictions.insert(pending_prediction.id);
 301
 302        if pending_prediction.drop_on_cancel {
 303            drop(pending_prediction.task);
 304        } else {
 305            cx.spawn(async move |this, cx| {
 306                let Some(prediction_id) = pending_prediction.task.await else {
 307                    return;
 308                };
 309
 310                this.update(cx, |this, cx| {
 311                    this.reject_prediction(
 312                        prediction_id,
 313                        EditPredictionRejectReason::Canceled,
 314                        false,
 315                        edit_prediction_types::EditPredictionDismissReason::Ignored,
 316                        cx,
 317                    );
 318                })
 319                .ok();
 320            })
 321            .detach()
 322        }
 323    }
 324
 325    fn active_buffer(
 326        &self,
 327        project: &Entity<Project>,
 328        cx: &App,
 329    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 330        let project = project.read(cx);
 331        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 332        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 333        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 334        Some((active_buffer, registered_buffer.last_position))
 335    }
 336}
 337
 338#[derive(Debug, Clone)]
 339struct CurrentEditPrediction {
 340    pub requested_by: PredictionRequestedBy,
 341    pub prediction: EditPrediction,
 342    pub was_shown: bool,
 343    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
 344}
 345
 346impl CurrentEditPrediction {
 347    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 348        let Some(new_edits) = self
 349            .prediction
 350            .interpolate(&self.prediction.buffer.read(cx))
 351        else {
 352            return false;
 353        };
 354
 355        if self.prediction.buffer != old_prediction.prediction.buffer {
 356            return true;
 357        }
 358
 359        let Some(old_edits) = old_prediction
 360            .prediction
 361            .interpolate(&old_prediction.prediction.buffer.read(cx))
 362        else {
 363            return true;
 364        };
 365
 366        let requested_by_buffer_id = self.requested_by.buffer_id();
 367
 368        // This reduces the occurrence of UI thrash from replacing edits
 369        //
 370        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 371        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 372            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 373            && old_edits.len() == 1
 374            && new_edits.len() == 1
 375        {
 376            let (old_range, old_text) = &old_edits[0];
 377            let (new_range, new_text) = &new_edits[0];
 378            new_range == old_range && new_text.starts_with(old_text.as_ref())
 379        } else {
 380            true
 381        }
 382    }
 383}
 384
 385#[derive(Debug, Clone)]
 386enum PredictionRequestedBy {
 387    DiagnosticsUpdate,
 388    Buffer(EntityId),
 389}
 390
 391impl PredictionRequestedBy {
 392    pub fn buffer_id(&self) -> Option<EntityId> {
 393        match self {
 394            PredictionRequestedBy::DiagnosticsUpdate => None,
 395            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 396        }
 397    }
 398}
 399
 400#[derive(Debug)]
 401struct PendingPrediction {
 402    id: usize,
 403    task: Task<Option<EditPredictionId>>,
 404    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
 405    /// If false, the task is awaited to completion so rejection can be reported.
 406    drop_on_cancel: bool,
 407}
 408
 409/// A prediction from the perspective of a buffer.
 410#[derive(Debug)]
 411enum BufferEditPrediction<'a> {
 412    Local { prediction: &'a EditPrediction },
 413    Jump { prediction: &'a EditPrediction },
 414}
 415
 416#[cfg(test)]
 417impl std::ops::Deref for BufferEditPrediction<'_> {
 418    type Target = EditPrediction;
 419
 420    fn deref(&self) -> &Self::Target {
 421        match self {
 422            BufferEditPrediction::Local { prediction } => prediction,
 423            BufferEditPrediction::Jump { prediction } => prediction,
 424        }
 425    }
 426}
 427
 428struct RegisteredBuffer {
 429    file: Option<Arc<dyn File>>,
 430    snapshot: TextBufferSnapshot,
 431    last_position: Option<Anchor>,
 432    _subscriptions: [gpui::Subscription; 2],
 433}
 434
 435#[derive(Clone)]
 436struct LastEvent {
 437    old_snapshot: TextBufferSnapshot,
 438    new_snapshot: TextBufferSnapshot,
 439    old_file: Option<Arc<dyn File>>,
 440    new_file: Option<Arc<dyn File>>,
 441    edit_range: Option<Range<Anchor>>,
 442    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
 443    last_edit_time: Option<Instant>,
 444}
 445
 446impl LastEvent {
 447    pub fn finalize(
 448        &self,
 449        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 450        cx: &App,
 451    ) -> Option<StoredEvent> {
 452        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 453        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 454
 455        let in_open_source_repo =
 456            [self.new_file.as_ref(), self.old_file.as_ref()]
 457                .iter()
 458                .all(|file| {
 459                    file.is_some_and(|file| {
 460                        license_detection_watchers
 461                            .get(&file.worktree_id(cx))
 462                            .is_some_and(|watcher| watcher.is_project_open_source())
 463                    })
 464                });
 465
 466        let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 467
 468        if path == old_path && diff.is_empty() {
 469            None
 470        } else {
 471            Some(StoredEvent {
 472                event: Arc::new(zeta_prompt::Event::BufferChange {
 473                    old_path,
 474                    path,
 475                    diff,
 476                    in_open_source_repo,
 477                    // TODO: Actually detect if this edit was predicted or not
 478                    predicted: false,
 479                }),
 480                old_snapshot: self.old_snapshot.clone(),
 481            })
 482        }
 483    }
 484
 485    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 486        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 487            return (self.clone(), None);
 488        };
 489
 490        let before = LastEvent {
 491            old_snapshot: self.old_snapshot.clone(),
 492            new_snapshot: boundary_snapshot.clone(),
 493            old_file: self.old_file.clone(),
 494            new_file: self.new_file.clone(),
 495            edit_range: None,
 496            snapshot_after_last_editing_pause: None,
 497            last_edit_time: self.last_edit_time,
 498        };
 499
 500        let after = LastEvent {
 501            old_snapshot: boundary_snapshot.clone(),
 502            new_snapshot: self.new_snapshot.clone(),
 503            old_file: self.old_file.clone(),
 504            new_file: self.new_file.clone(),
 505            edit_range: None,
 506            snapshot_after_last_editing_pause: None,
 507            last_edit_time: self.last_edit_time,
 508        };
 509
 510        (before, Some(after))
 511    }
 512}
 513
 514pub(crate) fn compute_diff_between_snapshots(
 515    old_snapshot: &TextBufferSnapshot,
 516    new_snapshot: &TextBufferSnapshot,
 517) -> Option<String> {
 518    let edits: Vec<Edit<usize>> = new_snapshot
 519        .edits_since::<usize>(&old_snapshot.version)
 520        .collect();
 521
 522    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 523
 524    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 525    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 526    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 527    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 528
 529    const CONTEXT_LINES: u32 = 3;
 530
 531    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 532    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 533    let old_context_end_row =
 534        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 535    let new_context_end_row =
 536        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 537
 538    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 539    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 540    let old_end_line_offset = old_snapshot
 541        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 542    let new_end_line_offset = new_snapshot
 543        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 544    let old_edit_range = old_start_line_offset..old_end_line_offset;
 545    let new_edit_range = new_start_line_offset..new_end_line_offset;
 546
 547    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 548    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 549
 550    let diff = language::unified_diff_with_offsets(
 551        &old_region_text,
 552        &new_region_text,
 553        old_context_start_row,
 554        new_context_start_row,
 555    );
 556
 557    Some(diff)
 558}
 559
 560fn buffer_path_with_id_fallback(
 561    file: Option<&Arc<dyn File>>,
 562    snapshot: &TextBufferSnapshot,
 563    cx: &App,
 564) -> Arc<Path> {
 565    if let Some(file) = file {
 566        file.full_path(cx).into()
 567    } else {
 568        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 569    }
 570}
 571
 572impl EditPredictionStore {
 573    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 574        cx.try_global::<EditPredictionStoreGlobal>()
 575            .map(|global| global.0.clone())
 576    }
 577
 578    pub fn global(
 579        client: &Arc<Client>,
 580        user_store: &Entity<UserStore>,
 581        cx: &mut App,
 582    ) -> Entity<Self> {
 583        cx.try_global::<EditPredictionStoreGlobal>()
 584            .map(|global| global.0.clone())
 585            .unwrap_or_else(|| {
 586                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 587                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 588                ep_store
 589            })
 590    }
 591
 592    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 593        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 594        let data_collection_choice = Self::load_data_collection_choice();
 595
 596        let llm_token = LlmApiToken::default();
 597
 598        let (reject_tx, reject_rx) = mpsc::unbounded();
 599        cx.background_spawn({
 600            let client = client.clone();
 601            let llm_token = llm_token.clone();
 602            let app_version = AppVersion::global(cx);
 603            let background_executor = cx.background_executor().clone();
 604            async move {
 605                Self::handle_rejected_predictions(
 606                    reject_rx,
 607                    client,
 608                    llm_token,
 609                    app_version,
 610                    background_executor,
 611                )
 612                .await
 613            }
 614        })
 615        .detach();
 616
 617        let this = Self {
 618            projects: HashMap::default(),
 619            client,
 620            user_store,
 621            llm_token,
 622            _llm_token_subscription: cx.subscribe(
 623                &refresh_llm_token_listener,
 624                |this, _listener, _event, cx| {
 625                    let client = this.client.clone();
 626                    let llm_token = this.llm_token.clone();
 627                    cx.spawn(async move |_this, _cx| {
 628                        llm_token.refresh(&client).await?;
 629                        anyhow::Ok(())
 630                    })
 631                    .detach_and_log_err(cx);
 632                },
 633            ),
 634            update_required: false,
 635            edit_prediction_model: EditPredictionModel::Zeta2 {
 636                version: Default::default(),
 637            },
 638            sweep_ai: SweepAi::new(cx),
 639            mercury: Mercury::new(cx),
 640            ollama: Ollama::new(),
 641
 642            data_collection_choice,
 643            reject_predictions_tx: reject_tx,
 644            rated_predictions: Default::default(),
 645            shown_predictions: Default::default(),
 646            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
 647                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
 648                Err(_) => None,
 649            },
 650        };
 651
 652        this
 653    }
 654
 655    #[cfg(test)]
 656    pub fn set_custom_predict_edits_url(&mut self, url: Url) {
 657        self.custom_predict_edits_url = Some(url.into());
 658    }
 659
 660    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 661        self.edit_prediction_model = model;
 662    }
 663
 664    pub fn icons(&self) -> edit_prediction_types::EditPredictionIconSet {
 665        use ui::IconName;
 666        match self.edit_prediction_model {
 667            EditPredictionModel::Ollama => {
 668                edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 669            }
 670            EditPredictionModel::Sweep => {
 671                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 672                    .with_disabled(IconName::SweepAiDisabled)
 673                    .with_up(IconName::SweepAiUp)
 674                    .with_down(IconName::SweepAiDown)
 675                    .with_error(IconName::SweepAiError)
 676            }
 677            EditPredictionModel::Mercury => {
 678                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 679            }
 680            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
 681                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 682                    .with_disabled(IconName::ZedPredictDisabled)
 683                    .with_up(IconName::ZedPredictUp)
 684                    .with_down(IconName::ZedPredictDown)
 685                    .with_error(IconName::ZedPredictError)
 686            }
 687        }
 688    }
 689
 690    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 691        self.sweep_ai.api_token.read(cx).has_key()
 692    }
 693
 694    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 695        self.mercury.api_token.read(cx).has_key()
 696    }
 697
 698    pub fn clear_history(&mut self) {
 699        for project_state in self.projects.values_mut() {
 700            project_state.events.clear();
 701            project_state.last_event.take();
 702        }
 703    }
 704
 705    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 706        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 707            project_state.events.clear();
 708            project_state.last_event.take();
 709        }
 710    }
 711
 712    pub fn edit_history_for_project(
 713        &self,
 714        project: &Entity<Project>,
 715        cx: &App,
 716    ) -> Vec<StoredEvent> {
 717        self.projects
 718            .get(&project.entity_id())
 719            .map(|project_state| project_state.events(cx))
 720            .unwrap_or_default()
 721    }
 722
 723    pub fn edit_history_for_project_with_pause_split_last_event(
 724        &self,
 725        project: &Entity<Project>,
 726        cx: &App,
 727    ) -> Vec<StoredEvent> {
 728        self.projects
 729            .get(&project.entity_id())
 730            .map(|project_state| project_state.events_split_by_pause(cx))
 731            .unwrap_or_default()
 732    }
 733
 734    pub fn context_for_project<'a>(
 735        &'a self,
 736        project: &Entity<Project>,
 737        cx: &'a mut App,
 738    ) -> Vec<RelatedFile> {
 739        self.projects
 740            .get(&project.entity_id())
 741            .map(|project| {
 742                project
 743                    .context
 744                    .update(cx, |context, cx| context.related_files(cx))
 745            })
 746            .unwrap_or_default()
 747    }
 748
 749    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 750        self.projects
 751            .get(&project.entity_id())
 752            .and_then(|project| project.copilot.clone())
 753    }
 754
 755    pub fn start_copilot_for_project(
 756        &mut self,
 757        project: &Entity<Project>,
 758        cx: &mut Context<Self>,
 759    ) -> Option<Entity<Copilot>> {
 760        if DisableAiSettings::get(None, cx).disable_ai {
 761            return None;
 762        }
 763        let state = self.get_or_init_project(project, cx);
 764
 765        if state.copilot.is_some() {
 766            return state.copilot.clone();
 767        }
 768        let _project = project.clone();
 769        let project = project.read(cx);
 770
 771        let node = project.node_runtime().cloned();
 772        if let Some(node) = node {
 773            let next_id = project.languages().next_language_server_id();
 774            let fs = project.fs().clone();
 775
 776            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 777            state.copilot = Some(copilot.clone());
 778            Some(copilot)
 779        } else {
 780            None
 781        }
 782    }
 783
 784    pub fn context_for_project_with_buffers<'a>(
 785        &'a self,
 786        project: &Entity<Project>,
 787        cx: &'a mut App,
 788    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 789        self.projects
 790            .get(&project.entity_id())
 791            .map(|project| {
 792                project
 793                    .context
 794                    .update(cx, |context, cx| context.related_files_with_buffers(cx))
 795            })
 796            .unwrap_or_default()
 797    }
 798
 799    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 800        if matches!(
 801            self.edit_prediction_model,
 802            EditPredictionModel::Zeta2 { .. }
 803        ) {
 804            self.user_store.read(cx).edit_prediction_usage()
 805        } else {
 806            None
 807        }
 808    }
 809
 810    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 811        self.get_or_init_project(project, cx);
 812    }
 813
 814    pub fn register_buffer(
 815        &mut self,
 816        buffer: &Entity<Buffer>,
 817        project: &Entity<Project>,
 818        cx: &mut Context<Self>,
 819    ) {
 820        let project_state = self.get_or_init_project(project, cx);
 821        Self::register_buffer_impl(project_state, buffer, project, cx);
 822    }
 823
 824    fn get_or_init_project(
 825        &mut self,
 826        project: &Entity<Project>,
 827        cx: &mut Context<Self>,
 828    ) -> &mut ProjectState {
 829        let entity_id = project.entity_id();
 830        self.projects
 831            .entry(entity_id)
 832            .or_insert_with(|| ProjectState {
 833                context: {
 834                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 835                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 836                        this.handle_excerpt_store_event(entity_id, event);
 837                    })
 838                    .detach();
 839                    related_excerpt_store
 840                },
 841                events: VecDeque::new(),
 842                last_event: None,
 843                recent_paths: VecDeque::new(),
 844                debug_tx: None,
 845                registered_buffers: HashMap::default(),
 846                current_prediction: None,
 847                cancelled_predictions: HashSet::default(),
 848                pending_predictions: ArrayVec::new(),
 849                next_pending_prediction_id: 0,
 850                last_prediction_refresh: None,
 851                license_detection_watchers: HashMap::default(),
 852                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
 853                _subscriptions: [
 854                    cx.subscribe(&project, Self::handle_project_event),
 855                    cx.observe_release(&project, move |this, _, cx| {
 856                        this.projects.remove(&entity_id);
 857                        cx.notify();
 858                    }),
 859                ],
 860                copilot: None,
 861            })
 862    }
 863
 864    pub fn remove_project(&mut self, project: &Entity<Project>) {
 865        self.projects.remove(&project.entity_id());
 866    }
 867
 868    fn handle_excerpt_store_event(
 869        &mut self,
 870        project_entity_id: EntityId,
 871        event: &RelatedExcerptStoreEvent,
 872    ) {
 873        if let Some(project_state) = self.projects.get(&project_entity_id) {
 874            if let Some(debug_tx) = project_state.debug_tx.clone() {
 875                match event {
 876                    RelatedExcerptStoreEvent::StartedRefresh => {
 877                        debug_tx
 878                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 879                                ContextRetrievalStartedDebugEvent {
 880                                    project_entity_id: project_entity_id,
 881                                    timestamp: Instant::now(),
 882                                    search_prompt: String::new(),
 883                                },
 884                            ))
 885                            .ok();
 886                    }
 887                    RelatedExcerptStoreEvent::FinishedRefresh {
 888                        cache_hit_count,
 889                        cache_miss_count,
 890                        mean_definition_latency,
 891                        max_definition_latency,
 892                    } => {
 893                        debug_tx
 894                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 895                                ContextRetrievalFinishedDebugEvent {
 896                                    project_entity_id: project_entity_id,
 897                                    timestamp: Instant::now(),
 898                                    metadata: vec![
 899                                        (
 900                                            "Cache Hits",
 901                                            format!(
 902                                                "{}/{}",
 903                                                cache_hit_count,
 904                                                cache_hit_count + cache_miss_count
 905                                            )
 906                                            .into(),
 907                                        ),
 908                                        (
 909                                            "Max LSP Time",
 910                                            format!("{} ms", max_definition_latency.as_millis())
 911                                                .into(),
 912                                        ),
 913                                        (
 914                                            "Mean LSP Time",
 915                                            format!("{} ms", mean_definition_latency.as_millis())
 916                                                .into(),
 917                                        ),
 918                                    ],
 919                                },
 920                            ))
 921                            .ok();
 922                    }
 923                }
 924            }
 925        }
 926    }
 927
 928    pub fn debug_info(
 929        &mut self,
 930        project: &Entity<Project>,
 931        cx: &mut Context<Self>,
 932    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 933        let project_state = self.get_or_init_project(project, cx);
 934        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 935        project_state.debug_tx = Some(debug_watch_tx);
 936        debug_watch_rx
 937    }
 938
 939    fn handle_project_event(
 940        &mut self,
 941        project: Entity<Project>,
 942        event: &project::Event,
 943        cx: &mut Context<Self>,
 944    ) {
 945        // TODO [zeta2] init with recent paths
 946        match event {
 947            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 948                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 949                    return;
 950                };
 951                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 952                if let Some(path) = path {
 953                    if let Some(ix) = project_state
 954                        .recent_paths
 955                        .iter()
 956                        .position(|probe| probe == &path)
 957                    {
 958                        project_state.recent_paths.remove(ix);
 959                    }
 960                    project_state.recent_paths.push_front(path);
 961                }
 962            }
 963            project::Event::DiagnosticsUpdated { .. } => {
 964                if cx.has_flag::<Zeta2FeatureFlag>() {
 965                    self.refresh_prediction_from_diagnostics(project, cx);
 966                }
 967            }
 968            _ => (),
 969        }
 970    }
 971
 972    fn register_buffer_impl<'a>(
 973        project_state: &'a mut ProjectState,
 974        buffer: &Entity<Buffer>,
 975        project: &Entity<Project>,
 976        cx: &mut Context<Self>,
 977    ) -> &'a mut RegisteredBuffer {
 978        let buffer_id = buffer.entity_id();
 979
 980        if let Some(file) = buffer.read(cx).file() {
 981            let worktree_id = file.worktree_id(cx);
 982            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
 983                project_state
 984                    .license_detection_watchers
 985                    .entry(worktree_id)
 986                    .or_insert_with(|| {
 987                        let project_entity_id = project.entity_id();
 988                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
 989                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
 990                            else {
 991                                return;
 992                            };
 993                            project_state
 994                                .license_detection_watchers
 995                                .remove(&worktree_id);
 996                        })
 997                        .detach();
 998                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
 999                    });
1000            }
1001        }
1002
1003        match project_state.registered_buffers.entry(buffer_id) {
1004            hash_map::Entry::Occupied(entry) => entry.into_mut(),
1005            hash_map::Entry::Vacant(entry) => {
1006                let buf = buffer.read(cx);
1007                let snapshot = buf.text_snapshot();
1008                let file = buf.file().cloned();
1009                let project_entity_id = project.entity_id();
1010                entry.insert(RegisteredBuffer {
1011                    snapshot,
1012                    file,
1013                    last_position: None,
1014                    _subscriptions: [
1015                        cx.subscribe(buffer, {
1016                            let project = project.downgrade();
1017                            move |this, buffer, event, cx| {
1018                                if let language::BufferEvent::Edited = event
1019                                    && let Some(project) = project.upgrade()
1020                                {
1021                                    this.report_changes_for_buffer(&buffer, &project, cx);
1022                                }
1023                            }
1024                        }),
1025                        cx.observe_release(buffer, move |this, _buffer, _cx| {
1026                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
1027                            else {
1028                                return;
1029                            };
1030                            project_state.registered_buffers.remove(&buffer_id);
1031                        }),
1032                    ],
1033                })
1034            }
1035        }
1036    }
1037
    /// Diffs `buffer` against the snapshot last recorded for it in `project`
    /// and folds the change into the project's edit history.
    ///
    /// 1. Ensures the project state and the buffer's registration exist.
    /// 2. Returns early if the buffer version equals the recorded one.
    /// 3. Swaps in the current file/snapshot, summarizes all edits since the
    ///    old version, and records a `UserActionRecord` for them.
    /// 4. Either coalesces the change into the in-progress `last_event` (same
    ///    buffer, contiguous version, nearby rows), or finalizes that event
    ///    into `events` and starts a new `LastEvent`.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the snapshot we last recorded for this buffer.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        // Aggregate every edit since the previously recorded version.
        // `edit_range` spans from the first edit's start to the latest edit's
        // end. NOTE(review): assumes edits are yielded in ascending order — confirm.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        if num_edits > 0 {
            // "Char" means exactly one unit inserted/deleted per edit.
            // NOTE(review): the totals come from `usize` offset ranges, which
            // are presumably byte offsets, so typing a single multi-byte
            // character is classified as a "selection" — confirm intended.
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                // Wall-clock milliseconds since the Unix epoch; 0 if the clock
                // reads earlier than the epoch.
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // This change continues the in-progress event only if it starts
            // from exactly the snapshot that event ended at.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // ...and its edit range overlaps, or lies within
            // CHANGE_GROUPING_LINE_SPAN rows of, the event's previous range.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && edit_range
                    .as_ref()
                    .zip(last_event.edit_range.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        if a.start > b.end {
                            a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else if b.start > a.end {
                            b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else {
                            true
                        }
                    });

            if should_coalesce {
                // If at least LAST_CHANGE_GROUPING_TIME passed since the
                // previous edit, remember the snapshot from before this burst.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                // NOTE(review): `new_file` is not refreshed here, so a file
                // change between coalesced edits keeps the earlier value — confirm.
                last_event.edit_range = edit_range;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Close out the previous event; `finalize` may drop it (returns Option).
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Evict the oldest entry so `events` holds at most
                // EVENT_COUNT_MAX - 1 finalized events after the push.
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1161
1162    fn prediction_at(
1163        &mut self,
1164        buffer: &Entity<Buffer>,
1165        position: Option<language::Anchor>,
1166        project: &Entity<Project>,
1167        cx: &App,
1168    ) -> Option<BufferEditPrediction<'_>> {
1169        let project_state = self.projects.get_mut(&project.entity_id())?;
1170        if let Some(position) = position
1171            && let Some(buffer) = project_state
1172                .registered_buffers
1173                .get_mut(&buffer.entity_id())
1174        {
1175            buffer.last_position = Some(position);
1176        }
1177
1178        let CurrentEditPrediction {
1179            requested_by,
1180            prediction,
1181            ..
1182        } = project_state.current_prediction.as_ref()?;
1183
1184        if prediction.targets_buffer(buffer.read(cx)) {
1185            Some(BufferEditPrediction::Local { prediction })
1186        } else {
1187            let show_jump = match requested_by {
1188                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1189                    requested_by_buffer_id == &buffer.entity_id()
1190                }
1191                PredictionRequestedBy::DiagnosticsUpdate => true,
1192            };
1193
1194            if show_jump {
1195                Some(BufferEditPrediction::Jump { prediction })
1196            } else {
1197                None
1198            }
1199        }
1200    }
1201
1202    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1203        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1204            return;
1205        };
1206
1207        let Some(current_prediction) = project_state.current_prediction.take() else {
1208            return;
1209        };
1210
1211        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1212            project_state.cancel_pending_prediction(pending_prediction, cx);
1213        }
1214
1215        match self.edit_prediction_model {
1216            EditPredictionModel::Sweep => {
1217                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1218            }
1219            EditPredictionModel::Mercury => {
1220                mercury::edit_prediction_accepted(
1221                    current_prediction.prediction.id,
1222                    self.client.http_client(),
1223                    cx,
1224                );
1225            }
1226            EditPredictionModel::Ollama => {}
1227            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1228                zeta2::edit_prediction_accepted(self, current_prediction, cx)
1229            }
1230        }
1231    }
1232
1233    async fn handle_rejected_predictions(
1234        rx: UnboundedReceiver<EditPredictionRejection>,
1235        client: Arc<Client>,
1236        llm_token: LlmApiToken,
1237        app_version: Version,
1238        background_executor: BackgroundExecutor,
1239    ) {
1240        let mut rx = std::pin::pin!(rx.peekable());
1241        let mut batched = Vec::new();
1242
1243        while let Some(rejection) = rx.next().await {
1244            batched.push(rejection);
1245
1246            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1247                select_biased! {
1248                    next = rx.as_mut().peek().fuse() => {
1249                        if next.is_some() {
1250                            continue;
1251                        }
1252                    }
1253                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1254                }
1255            }
1256
1257            let url = client
1258                .http_client()
1259                .build_zed_llm_url("/predict_edits/reject", &[])
1260                .unwrap();
1261
1262            let flush_count = batched
1263                .len()
1264                // in case items have accumulated after failure
1265                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1266            let start = batched.len() - flush_count;
1267
1268            let body = RejectEditPredictionsBodyRef {
1269                rejections: &batched[start..],
1270            };
1271
1272            let result = Self::send_api_request::<()>(
1273                |builder| {
1274                    let req = builder
1275                        .uri(url.as_ref())
1276                        .body(serde_json::to_string(&body)?.into());
1277                    anyhow::Ok(req?)
1278                },
1279                client.clone(),
1280                llm_token.clone(),
1281                app_version.clone(),
1282                true,
1283            )
1284            .await;
1285
1286            if result.log_err().is_some() {
1287                batched.drain(start..);
1288            }
1289        }
1290    }
1291
1292    fn reject_current_prediction(
1293        &mut self,
1294        reason: EditPredictionRejectReason,
1295        project: &Entity<Project>,
1296        dismiss_reason: edit_prediction_types::EditPredictionDismissReason,
1297        cx: &App,
1298    ) {
1299        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1300            project_state.pending_predictions.clear();
1301            if let Some(prediction) = project_state.current_prediction.take() {
1302                self.reject_prediction(
1303                    prediction.prediction.id,
1304                    reason,
1305                    prediction.was_shown,
1306                    dismiss_reason,
1307                    cx,
1308                );
1309            }
1310        };
1311    }
1312
1313    fn did_show_current_prediction(
1314        &mut self,
1315        project: &Entity<Project>,
1316        display_type: edit_prediction_types::SuggestionDisplayType,
1317        cx: &mut Context<Self>,
1318    ) {
1319        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1320            return;
1321        };
1322
1323        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1324            return;
1325        };
1326
1327        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1328        let previous_shown_with = current_prediction.shown_with;
1329
1330        if previous_shown_with.is_none() || !is_jump {
1331            current_prediction.shown_with = Some(display_type);
1332        }
1333
1334        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1335
1336        if is_first_non_jump_show {
1337            current_prediction.was_shown = true;
1338        }
1339
1340        let display_type_changed = previous_shown_with != Some(display_type);
1341
1342        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1343            sweep_ai::edit_prediction_shown(
1344                &self.sweep_ai,
1345                self.client.clone(),
1346                &current_prediction.prediction,
1347                display_type,
1348                cx,
1349            );
1350        }
1351
1352        if is_first_non_jump_show {
1353            self.shown_predictions
1354                .push_front(current_prediction.prediction.clone());
1355            if self.shown_predictions.len() > 50 {
1356                let completion = self.shown_predictions.pop_back().unwrap();
1357                self.rated_predictions.remove(&completion.id);
1358            }
1359        }
1360    }
1361
1362    fn reject_prediction(
1363        &mut self,
1364        prediction_id: EditPredictionId,
1365        reason: EditPredictionRejectReason,
1366        was_shown: bool,
1367        dismiss_reason: edit_prediction_types::EditPredictionDismissReason,
1368        cx: &App,
1369    ) {
1370        match self.edit_prediction_model {
1371            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1372                if self.custom_predict_edits_url.is_none() {
1373                    self.reject_predictions_tx
1374                        .unbounded_send(EditPredictionRejection {
1375                            request_id: prediction_id.to_string(),
1376                            reason,
1377                            was_shown,
1378                        })
1379                        .log_err();
1380                }
1381            }
1382            EditPredictionModel::Sweep | EditPredictionModel::Ollama => {}
1383            EditPredictionModel::Mercury => {
1384                mercury::edit_prediction_rejected(
1385                    prediction_id,
1386                    was_shown,
1387                    dismiss_reason,
1388                    self.client.http_client(),
1389                    cx,
1390                );
1391            }
1392        }
1393    }
1394
1395    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1396        self.projects
1397            .get(&project.entity_id())
1398            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1399    }
1400
1401    pub fn refresh_prediction_from_buffer(
1402        &mut self,
1403        project: Entity<Project>,
1404        buffer: Entity<Buffer>,
1405        position: language::Anchor,
1406        cx: &mut Context<Self>,
1407    ) {
1408        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1409            let Some(request_task) = this
1410                .update(cx, |this, cx| {
1411                    this.request_prediction(
1412                        &project,
1413                        &buffer,
1414                        position,
1415                        PredictEditsRequestTrigger::Other,
1416                        cx,
1417                    )
1418                })
1419                .log_err()
1420            else {
1421                return Task::ready(anyhow::Ok(None));
1422            };
1423
1424            cx.spawn(async move |_cx| {
1425                request_task.await.map(|prediction_result| {
1426                    prediction_result.map(|prediction_result| {
1427                        (
1428                            prediction_result,
1429                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1430                        )
1431                    })
1432                })
1433            })
1434        })
1435    }
1436
1437    pub fn refresh_prediction_from_diagnostics(
1438        &mut self,
1439        project: Entity<Project>,
1440        cx: &mut Context<Self>,
1441    ) {
1442        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1443            return;
1444        };
1445
1446        // Prefer predictions from buffer
1447        if project_state.current_prediction.is_some() {
1448            return;
1449        };
1450
1451        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1452            let Some((active_buffer, snapshot, cursor_point)) = this
1453                .read_with(cx, |this, cx| {
1454                    let project_state = this.projects.get(&project.entity_id())?;
1455                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
1456                    let snapshot = buffer.read(cx).snapshot();
1457
1458                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
1459                        return None;
1460                    }
1461
1462                    let cursor_point = position
1463                        .map(|pos| pos.to_point(&snapshot))
1464                        .unwrap_or_default();
1465
1466                    Some((buffer, snapshot, cursor_point))
1467                })
1468                .log_err()
1469                .flatten()
1470            else {
1471                return Task::ready(anyhow::Ok(None));
1472            };
1473
1474            cx.spawn(async move |cx| {
1475                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1476                    active_buffer,
1477                    &snapshot,
1478                    Default::default(),
1479                    cursor_point,
1480                    &project,
1481                    cx,
1482                )
1483                .await?
1484                else {
1485                    return anyhow::Ok(None);
1486                };
1487
1488                let Some(prediction_result) = this
1489                    .update(cx, |this, cx| {
1490                        this.request_prediction(
1491                            &project,
1492                            &jump_buffer,
1493                            jump_position,
1494                            PredictEditsRequestTrigger::Diagnostics,
1495                            cx,
1496                        )
1497                    })?
1498                    .await?
1499                else {
1500                    return anyhow::Ok(None);
1501                };
1502
1503                this.update(cx, |this, cx| {
1504                    Some((
1505                        if this
1506                            .get_or_init_project(&project, cx)
1507                            .current_prediction
1508                            .is_none()
1509                        {
1510                            prediction_result
1511                        } else {
1512                            EditPredictionResult {
1513                                id: prediction_result.id,
1514                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1515                            }
1516                        },
1517                        PredictionRequestedBy::DiagnosticsUpdate,
1518                    ))
1519                })
1520            })
1521        });
1522    }
1523
1524    fn predictions_enabled_at(
1525        snapshot: &BufferSnapshot,
1526        position: Option<language::Anchor>,
1527        cx: &App,
1528    ) -> bool {
1529        let file = snapshot.file();
1530        let all_settings = all_language_settings(file, cx);
1531        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1532            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1533        {
1534            return false;
1535        }
1536
1537        if let Some(last_position) = position {
1538            let settings = snapshot.settings_at(last_position, cx);
1539
1540            if !settings.edit_predictions_disabled_in.is_empty()
1541                && let Some(scope) = snapshot.language_scope_at(last_position)
1542                && let Some(scope_name) = scope.override_name()
1543                && settings
1544                    .edit_predictions_disabled_in
1545                    .iter()
1546                    .any(|s| s == scope_name)
1547            {
1548                return false;
1549            }
1550        }
1551
1552        true
1553    }
1554
1555    #[cfg(not(test))]
1556    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1557    #[cfg(test)]
1558    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1559
1560    fn queue_prediction_refresh(
1561        &mut self,
1562        project: Entity<Project>,
1563        throttle_entity: EntityId,
1564        cx: &mut Context<Self>,
1565        do_refresh: impl FnOnce(
1566            WeakEntity<Self>,
1567            &mut AsyncApp,
1568        )
1569            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1570        + 'static,
1571    ) {
1572        let is_ollama = self.edit_prediction_model == EditPredictionModel::Ollama;
1573        let drop_on_cancel = is_ollama;
1574        let max_pending_predictions = if is_ollama { 1 } else { 2 };
1575        let project_state = self.get_or_init_project(&project, cx);
1576        let pending_prediction_id = project_state.next_pending_prediction_id;
1577        project_state.next_pending_prediction_id += 1;
1578        let last_request = project_state.last_prediction_refresh;
1579
1580        let task = cx.spawn(async move |this, cx| {
1581            if let Some((last_entity, last_timestamp)) = last_request
1582                && throttle_entity == last_entity
1583                && let Some(timeout) =
1584                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1585            {
1586                cx.background_executor().timer(timeout).await;
1587            }
1588
1589            // If this task was cancelled before the throttle timeout expired,
1590            // do not perform a request.
1591            let mut is_cancelled = true;
1592            this.update(cx, |this, cx| {
1593                let project_state = this.get_or_init_project(&project, cx);
1594                if !project_state
1595                    .cancelled_predictions
1596                    .remove(&pending_prediction_id)
1597                {
1598                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1599                    is_cancelled = false;
1600                }
1601            })
1602            .ok();
1603            if is_cancelled {
1604                return None;
1605            }
1606
1607            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1608            let new_prediction_id = new_prediction_result
1609                .as_ref()
1610                .map(|(prediction, _)| prediction.id.clone());
1611
1612            // When a prediction completes, remove it from the pending list, and cancel
1613            // any pending predictions that were enqueued before it.
1614            this.update(cx, |this, cx| {
1615                let project_state = this.get_or_init_project(&project, cx);
1616
1617                let is_cancelled = project_state
1618                    .cancelled_predictions
1619                    .remove(&pending_prediction_id);
1620
1621                let new_current_prediction = if !is_cancelled
1622                    && let Some((prediction_result, requested_by)) = new_prediction_result
1623                {
1624                    match prediction_result.prediction {
1625                        Ok(prediction) => {
1626                            let new_prediction = CurrentEditPrediction {
1627                                requested_by,
1628                                prediction,
1629                                was_shown: false,
1630                                shown_with: None,
1631                            };
1632
1633                            if let Some(current_prediction) =
1634                                project_state.current_prediction.as_ref()
1635                            {
1636                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1637                                {
1638                                    this.reject_current_prediction(
1639                                        EditPredictionRejectReason::Replaced,
1640                                        &project,
1641                                        edit_prediction_types::EditPredictionDismissReason::Ignored,
1642                                        cx,
1643                                    );
1644
1645                                    Some(new_prediction)
1646                                } else {
1647                                    this.reject_prediction(
1648                                        new_prediction.prediction.id,
1649                                        EditPredictionRejectReason::CurrentPreferred,
1650                                        false,
1651                                        edit_prediction_types::EditPredictionDismissReason::Ignored,
1652                                        cx,
1653                                    );
1654                                    None
1655                                }
1656                            } else {
1657                                Some(new_prediction)
1658                            }
1659                        }
1660                        Err(reject_reason) => {
1661                            this.reject_prediction(
1662                                prediction_result.id,
1663                                reject_reason,
1664                                false,
1665                                edit_prediction_types::EditPredictionDismissReason::Ignored,
1666                                cx,
1667                            );
1668                            None
1669                        }
1670                    }
1671                } else {
1672                    None
1673                };
1674
1675                let project_state = this.get_or_init_project(&project, cx);
1676
1677                if let Some(new_prediction) = new_current_prediction {
1678                    project_state.current_prediction = Some(new_prediction);
1679                }
1680
1681                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1682                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1683                    if pending_prediction.id == pending_prediction_id {
1684                        pending_predictions.remove(ix);
1685                        for pending_prediction in pending_predictions.drain(0..ix) {
1686                            project_state.cancel_pending_prediction(pending_prediction, cx)
1687                        }
1688                        break;
1689                    }
1690                }
1691                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1692                cx.notify();
1693            })
1694            .ok();
1695
1696            new_prediction_id
1697        });
1698
1699        if project_state.pending_predictions.len() < max_pending_predictions {
1700            project_state.pending_predictions.push(PendingPrediction {
1701                id: pending_prediction_id,
1702                task,
1703                drop_on_cancel,
1704            });
1705        } else {
1706            let pending_prediction = project_state.pending_predictions.pop().unwrap();
1707            project_state.pending_predictions.push(PendingPrediction {
1708                id: pending_prediction_id,
1709                task,
1710                drop_on_cancel,
1711            });
1712            project_state.cancel_pending_prediction(pending_prediction, cx);
1713        }
1714    }
1715
1716    pub fn request_prediction(
1717        &mut self,
1718        project: &Entity<Project>,
1719        active_buffer: &Entity<Buffer>,
1720        position: language::Anchor,
1721        trigger: PredictEditsRequestTrigger,
1722        cx: &mut Context<Self>,
1723    ) -> Task<Result<Option<EditPredictionResult>>> {
1724        self.request_prediction_internal(
1725            project.clone(),
1726            active_buffer.clone(),
1727            position,
1728            trigger,
1729            cx.has_flag::<Zeta2FeatureFlag>(),
1730            cx,
1731        )
1732    }
1733
    /// Builds the model input for a prediction at `position` in `active_buffer` and dispatches
    /// it to the configured `EditPredictionModel`.
    ///
    /// Along the way this may:
    /// - append a synthetic `CursorMovement` user action when the most recent recorded action
    ///   was in the active buffer at a different offset than `position`;
    /// - capture an example for telemetry when data collection is permitted for the file, all
    ///   events and all related files, and `should_sample_edit_prediction_example_capture`
    ///   selects this request;
    /// - for `Zeta1`, additionally send a detached Zeta2 request with the `Testing` trigger
    ///   when `should_send_testing_zeta2_request()` returns true.
    ///
    /// If the model yields no prediction, `allow_jump` is set and there are recorded events,
    /// the nearest diagnostic outside the searched range (see `next_diagnostic_location`) is
    /// used as a new position and the request is retried once with `allow_jump = false`.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Rows above and below the cursor that make up the diagnostic search range.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        // Make sure the per-project state exists before borrowing it immutably below.
        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let stored_events = project_state.events_split_by_pause(cx);
        let has_events = !stored_events.is_empty();
        let events: Vec<Arc<zeta_prompt::Event>> =
            stored_events.into_iter().map(|e| e.event).collect();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let current_offset = position.to_offset(&snapshot);

        let mut user_actions: Vec<UserActionRecord> =
            project_state.user_actions.iter().cloned().collect();

        // If the cursor has moved within the buffer of the most recent action, record that
        // movement so the request carries the current cursor location.
        if let Some(last_action) = user_actions.last() {
            if last_action.buffer_id == active_buffer.entity_id()
                && current_offset != last_action.offset
            {
                // Falls back to 0 if the system clock reads earlier than the Unix epoch.
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                user_actions.push(UserActionRecord {
                    action_type: UserActionType::CursorMovement,
                    buffer_id: active_buffer.entity_id(),
                    line_number: cursor_point.row,
                    offset: current_offset,
                    timestamp_epoch_ms,
                });
            }
        }
        // Rows [row - 20 (clamped at 0), row + 20]. Passed to the model and reused by the
        // jump fallback to skip diagnostics already covered by this request.
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let related_files = self.context_for_project(&project, cx);

        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer.clone(),
            snapshot: snapshot.clone(),
            position,
            events,
            related_files,
            recent_paths: project_state.recent_paths.clone(),
            trigger,
            diagnostic_search_range: diagnostic_search_range.clone(),
            debug_tx,
            user_actions,
        };

        // Example capture requires data collection to be permitted for the active file,
        // every event, and every related file.
        let can_collect_example = snapshot
            .file()
            .is_some_and(|file| self.can_collect_file(&project, file, cx))
            && self.can_collect_events(&inputs.events, cx)
            && self.can_collect_related_files(&project, cx);

        if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
            let events_for_capture =
                self.edit_history_for_project_with_pause_split_last_event(&project, cx);
            let related_files_for_capture = inputs.related_files.clone();
            if let Some(example_task) = capture_example::capture_example(
                project.clone(),
                active_buffer.clone(),
                position,
                events_for_capture,
                related_files_for_capture,
                false,
                cx,
            ) {
                cx.spawn(async move |_this, _cx| {
                    let example = example_task.await?;
                    telemetry::event!("Edit Prediction Example Captured", example = example);
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx);
            }
        }
        // Dispatch to the configured model.
        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => {
                // Optionally mirror the request to Zeta2 with the `Testing` trigger; the task
                // is detached, so its result is not used here.
                if should_send_testing_zeta2_request() {
                    let mut zeta2_inputs = inputs.clone();
                    zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing;
                    zeta2::request_prediction_with_zeta2(
                        self,
                        zeta2_inputs,
                        Default::default(),
                        cx,
                    )
                    .detach();
                }
                zeta1::request_prediction_with_zeta1(self, inputs, cx)
            }
            EditPredictionModel::Zeta2 { version } => {
                zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
            }
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
            EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            // No prediction: optionally retry once at a diagnostic outside the searched range.
            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    // `allow_jump = false` so the retry cannot jump again.
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1884
1885    async fn next_diagnostic_location(
1886        active_buffer: Entity<Buffer>,
1887        active_buffer_snapshot: &BufferSnapshot,
1888        active_buffer_diagnostic_search_range: Range<Point>,
1889        active_buffer_cursor_point: Point,
1890        project: &Entity<Project>,
1891        cx: &mut AsyncApp,
1892    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1893        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1894        let mut jump_location = active_buffer_snapshot
1895            .diagnostic_groups(None)
1896            .into_iter()
1897            .filter_map(|(_, group)| {
1898                let range = &group.entries[group.primary_ix]
1899                    .range
1900                    .to_point(&active_buffer_snapshot);
1901                if range.overlaps(&active_buffer_diagnostic_search_range) {
1902                    None
1903                } else {
1904                    Some(range.start)
1905                }
1906            })
1907            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1908            .map(|position| {
1909                (
1910                    active_buffer.clone(),
1911                    active_buffer_snapshot.anchor_before(position),
1912                )
1913            });
1914
1915        if jump_location.is_none() {
1916            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1917                let file = buffer.file()?;
1918
1919                Some(ProjectPath {
1920                    worktree_id: file.worktree_id(cx),
1921                    path: file.path().clone(),
1922                })
1923            });
1924
1925            let buffer_task = project.update(cx, |project, cx| {
1926                let (path, _, _) = project
1927                    .diagnostic_summaries(false, cx)
1928                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1929                    .max_by_key(|(path, _, _)| {
1930                        // find the buffer with errors that shares most parent directories
1931                        path.path
1932                            .components()
1933                            .zip(
1934                                active_buffer_path
1935                                    .as_ref()
1936                                    .map(|p| p.path.components())
1937                                    .unwrap_or_default(),
1938                            )
1939                            .take_while(|(a, b)| a == b)
1940                            .count()
1941                    })?;
1942
1943                Some(project.open_buffer(path, cx))
1944            });
1945
1946            if let Some(buffer_task) = buffer_task {
1947                let closest_buffer = buffer_task.await?;
1948
1949                jump_location = closest_buffer
1950                    .read_with(cx, |buffer, _cx| {
1951                        buffer
1952                            .buffer_diagnostics(None)
1953                            .into_iter()
1954                            .min_by_key(|entry| entry.diagnostic.severity)
1955                            .map(|entry| entry.range.start)
1956                    })
1957                    .map(|position| (closest_buffer, position));
1958            }
1959        }
1960
1961        anyhow::Ok(jump_location)
1962    }
1963
1964    async fn send_raw_llm_request(
1965        request: RawCompletionRequest,
1966        client: Arc<Client>,
1967        custom_url: Option<Arc<Url>>,
1968        llm_token: LlmApiToken,
1969        app_version: Version,
1970    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1971        let url = if let Some(custom_url) = custom_url {
1972            custom_url.as_ref().clone()
1973        } else {
1974            client
1975                .http_client()
1976                .build_zed_llm_url("/predict_edits/raw", &[])?
1977        };
1978
1979        Self::send_api_request(
1980            |builder| {
1981                let req = builder
1982                    .uri(url.as_ref())
1983                    .body(serde_json::to_string(&request)?.into());
1984                Ok(req?)
1985            },
1986            client,
1987            llm_token,
1988            app_version,
1989            true,
1990        )
1991        .await
1992    }
1993
1994    pub(crate) async fn send_v3_request(
1995        input: ZetaPromptInput,
1996        prompt_version: ZetaVersion,
1997        client: Arc<Client>,
1998        llm_token: LlmApiToken,
1999        app_version: Version,
2000        trigger: PredictEditsRequestTrigger,
2001    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2002        let url = client
2003            .http_client()
2004            .build_zed_llm_url("/predict_edits/v3", &[])?;
2005
2006        let request = PredictEditsV3Request {
2007            input,
2008            model: EDIT_PREDICTIONS_MODEL_ID.clone(),
2009            prompt_version,
2010            trigger,
2011        };
2012
2013        Self::send_api_request(
2014            |builder| {
2015                let req = builder
2016                    .uri(url.as_ref())
2017                    .body(serde_json::to_string(&request)?.into());
2018                Ok(req?)
2019            },
2020            client,
2021            llm_token,
2022            app_version,
2023            true,
2024        )
2025        .await
2026    }
2027
2028    fn handle_api_response<T>(
2029        this: &WeakEntity<Self>,
2030        response: Result<(T, Option<EditPredictionUsage>)>,
2031        cx: &mut gpui::AsyncApp,
2032    ) -> Result<T> {
2033        match response {
2034            Ok((data, usage)) => {
2035                if let Some(usage) = usage {
2036                    this.update(cx, |this, cx| {
2037                        this.user_store.update(cx, |user_store, cx| {
2038                            user_store.update_edit_prediction_usage(usage, cx);
2039                        });
2040                    })
2041                    .ok();
2042                }
2043                Ok(data)
2044            }
2045            Err(err) => {
2046                if err.is::<ZedUpdateRequiredError>() {
2047                    cx.update(|cx| {
2048                        this.update(cx, |this, _cx| {
2049                            this.update_required = true;
2050                        })
2051                        .ok();
2052
2053                        let error_message: SharedString = err.to_string().into();
2054                        show_app_notification(
2055                            NotificationId::unique::<ZedUpdateRequiredError>(),
2056                            cx,
2057                            move |cx| {
2058                                cx.new(|cx| {
2059                                    ErrorMessagePrompt::new(error_message.clone(), cx)
2060                                        .with_link_button("Update Zed", "https://zed.dev/releases")
2061                                })
2062                            },
2063                        );
2064                    });
2065                }
2066                Err(err)
2067            }
2068        }
2069    }
2070
2071    async fn send_api_request<Res>(
2072        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2073        client: Arc<Client>,
2074        llm_token: LlmApiToken,
2075        app_version: Version,
2076        require_auth: bool,
2077    ) -> Result<(Res, Option<EditPredictionUsage>)>
2078    where
2079        Res: DeserializeOwned,
2080    {
2081        let http_client = client.http_client();
2082
2083        let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
2084            Some(custom_token)
2085        } else if require_auth {
2086            Some(llm_token.acquire(&client).await?)
2087        } else {
2088            llm_token.acquire(&client).await.ok()
2089        };
2090        let mut did_retry = false;
2091
2092        loop {
2093            let request_builder = http_client::Request::builder().method(Method::POST);
2094
2095            let mut request_builder = request_builder
2096                .header("Content-Type", "application/json")
2097                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2098
2099            // Only add Authorization header if we have a token
2100            if let Some(ref token_value) = token {
2101                request_builder =
2102                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2103            }
2104
2105            let request = build(request_builder)?;
2106
2107            let mut response = http_client.send(request).await?;
2108
2109            if let Some(minimum_required_version) = response
2110                .headers()
2111                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2112                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2113            {
2114                anyhow::ensure!(
2115                    app_version >= minimum_required_version,
2116                    ZedUpdateRequiredError {
2117                        minimum_version: minimum_required_version
2118                    }
2119                );
2120            }
2121
2122            if response.status().is_success() {
2123                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2124
2125                let mut body = Vec::new();
2126                response.body_mut().read_to_end(&mut body).await?;
2127                return Ok((serde_json::from_slice(&body)?, usage));
2128            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2129                did_retry = true;
2130                token = Some(llm_token.refresh(&client).await?);
2131            } else {
2132                let mut body = String::new();
2133                response.body_mut().read_to_string(&mut body).await?;
2134                anyhow::bail!(
2135                    "Request failed with status: {:?}\nBody: {}",
2136                    response.status(),
2137                    body
2138                );
2139            }
2140        }
2141    }
2142
2143    pub fn refresh_context(
2144        &mut self,
2145        project: &Entity<Project>,
2146        buffer: &Entity<language::Buffer>,
2147        cursor_position: language::Anchor,
2148        cx: &mut Context<Self>,
2149    ) {
2150        self.get_or_init_project(project, cx)
2151            .context
2152            .update(cx, |store, cx| {
2153                store.refresh(buffer.clone(), cursor_position, cx);
2154            });
2155    }
2156
2157    #[cfg(feature = "cli-support")]
2158    pub fn set_context_for_buffer(
2159        &mut self,
2160        project: &Entity<Project>,
2161        related_files: Vec<RelatedFile>,
2162        cx: &mut Context<Self>,
2163    ) {
2164        self.get_or_init_project(project, cx)
2165            .context
2166            .update(cx, |store, cx| {
2167                store.set_related_files(related_files, cx);
2168            });
2169    }
2170
2171    #[cfg(feature = "cli-support")]
2172    pub fn set_recent_paths_for_project(
2173        &mut self,
2174        project: &Entity<Project>,
2175        paths: impl IntoIterator<Item = project::ProjectPath>,
2176        cx: &mut Context<Self>,
2177    ) {
2178        let project_state = self.get_or_init_project(project, cx);
2179        project_state.recent_paths = paths.into_iter().collect();
2180    }
2181
2182    fn is_file_open_source(
2183        &self,
2184        project: &Entity<Project>,
2185        file: &Arc<dyn File>,
2186        cx: &App,
2187    ) -> bool {
2188        if !file.is_local() || file.is_private() {
2189            return false;
2190        }
2191        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2192            return false;
2193        };
2194        project_state
2195            .license_detection_watchers
2196            .get(&file.worktree_id(cx))
2197            .as_ref()
2198            .is_some_and(|watcher| watcher.is_project_open_source())
2199    }
2200
2201    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2202        self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2203    }
2204
2205    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2206        if !self.data_collection_choice.is_enabled(cx) {
2207            return false;
2208        }
2209        events.iter().all(|event| {
2210            matches!(
2211                event.as_ref(),
2212                zeta_prompt::Event::BufferChange {
2213                    in_open_source_repo: true,
2214                    ..
2215                }
2216            )
2217        })
2218    }
2219
2220    fn can_collect_related_files(&self, project: &Entity<Project>, cx: &mut App) -> bool {
2221        if !self.data_collection_choice.is_enabled(cx) {
2222            return false;
2223        }
2224
2225        let related_with_buffers = self.context_for_project_with_buffers(project, cx);
2226
2227        related_with_buffers.iter().all(|(_, buffer)| {
2228            buffer
2229                .read(cx)
2230                .file()
2231                .is_some_and(|file| self.is_file_open_source(project, &file, cx))
2232        })
2233    }
2234
2235    fn load_data_collection_choice() -> DataCollectionChoice {
2236        let choice = KEY_VALUE_STORE
2237            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2238            .log_err()
2239            .flatten();
2240
2241        match choice.as_deref() {
2242            Some("true") => DataCollectionChoice::Enabled,
2243            Some("false") => DataCollectionChoice::Disabled,
2244            Some(_) => {
2245                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2246                DataCollectionChoice::NotAnswered
2247            }
2248            None => DataCollectionChoice::NotAnswered,
2249        }
2250    }
2251
2252    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2253        self.data_collection_choice = self.data_collection_choice.toggle();
2254        let new_choice = self.data_collection_choice;
2255        let is_enabled = new_choice.is_enabled(cx);
2256        db::write_and_log(cx, move || {
2257            KEY_VALUE_STORE.write_kvp(
2258                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2259                is_enabled.to_string(),
2260            )
2261        });
2262    }
2263
    /// Returns a double-ended iterator over the entries of `shown_predictions`.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }

    /// Number of entries in `shown_predictions`.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }

    /// Whether `id` is recorded in `rated_predictions` (which `rate_prediction` inserts into).
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2275
2276    pub fn rate_prediction(
2277        &mut self,
2278        prediction: &EditPrediction,
2279        rating: EditPredictionRating,
2280        feedback: String,
2281        cx: &mut Context<Self>,
2282    ) {
2283        self.rated_predictions.insert(prediction.id.clone());
2284        telemetry::event!(
2285            "Edit Prediction Rated",
2286            rating,
2287            inputs = prediction.inputs,
2288            output = prediction
2289                .edit_preview
2290                .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2291            feedback
2292        );
2293        self.client.telemetry().flush_events().detach();
2294        cx.notify();
2295    }
2296}
2297
2298pub(crate) fn filter_redundant_excerpts(
2299    mut related_files: Vec<RelatedFile>,
2300    cursor_path: &Path,
2301    cursor_row_range: Range<u32>,
2302) -> Vec<RelatedFile> {
2303    for file in &mut related_files {
2304        if file.path.as_ref() == cursor_path {
2305            file.excerpts.retain(|excerpt| {
2306                excerpt.row_range.start < cursor_row_range.start
2307                    || excerpt.row_range.end > cursor_row_range.end
2308            });
2309        }
2310    }
2311    related_files.retain(|file| !file.excerpts.is_empty());
2312    related_files
2313}
2314
/// Error produced by `send_api_request` when the server's minimum-required-version header
/// names a version newer than the running app.
///
/// `handle_api_response` recognizes this error to set `update_required` and show an
/// "Update Zed" notification.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum version reported by the server.
    minimum_version: Version,
}
2322
/// The user's edit-prediction data-collection preference.
///
/// Persisted in the key-value store as `"true"` / `"false"`. A missing or unrecognized
/// stored value loads as `NotAnswered`.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No preference stored yet; treated as disabled unless the user is staff.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out; treated as disabled unless the user is staff.
    Disabled,
}
2329
2330impl DataCollectionChoice {
2331    pub fn is_enabled(self, cx: &App) -> bool {
2332        if cx.is_staff() {
2333            return true;
2334        }
2335        match self {
2336            Self::Enabled => true,
2337            Self::NotAnswered | Self::Disabled => false,
2338        }
2339    }
2340
2341    #[must_use]
2342    pub fn toggle(&self) -> DataCollectionChoice {
2343        match self {
2344            Self::Enabled => Self::Disabled,
2345            Self::Disabled => Self::Enabled,
2346            Self::NotAnswered => Self::Enabled,
2347        }
2348    }
2349}
2350
2351impl From<bool> for DataCollectionChoice {
2352    fn from(value: bool) -> Self {
2353        match value {
2354            true => DataCollectionChoice::Enabled,
2355            false => DataCollectionChoice::Disabled,
2356        }
2357    }
2358}
2359
2360struct ZedPredictUpsell;
2361
2362impl Dismissable for ZedPredictUpsell {
2363    const KEY: &'static str = "dismissed-edit-predict-upsell";
2364
2365    fn dismissed() -> bool {
2366        // To make this backwards compatible with older versions of Zed, we
2367        // check if the user has seen the previous Edit Prediction Onboarding
2368        // before, by checking the data collection choice which was written to
2369        // the database once the user clicked on "Accept and Enable"
2370        if KEY_VALUE_STORE
2371            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2372            .log_err()
2373            .is_some_and(|s| s.is_some())
2374        {
2375            return true;
2376        }
2377
2378        KEY_VALUE_STORE
2379            .read_kvp(Self::KEY)
2380            .log_err()
2381            .is_some_and(|s| s.is_some())
2382    }
2383}
2384
2385pub fn should_show_upsell_modal() -> bool {
2386    !ZedPredictUpsell::dismissed()
2387}
2388
2389pub fn init(cx: &mut App) {
2390    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2391        workspace.register_action(
2392            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2393                ZedPredictModal::toggle(
2394                    workspace,
2395                    workspace.user_store().clone(),
2396                    workspace.client().clone(),
2397                    window,
2398                    cx,
2399                )
2400            },
2401        );
2402
2403        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2404            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2405                settings
2406                    .project
2407                    .all_languages
2408                    .edit_predictions
2409                    .get_or_insert_default()
2410                    .provider = Some(EditPredictionProvider::None)
2411            });
2412        });
2413        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2414            EditPredictionStore::try_global(cx).and_then(|store| {
2415                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2416            })
2417        }
2418
2419        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2420            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2421                copilot_ui::initiate_sign_in(copilot, window, cx);
2422            }
2423        });
2424        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2425            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2426                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2427            }
2428        });
2429        workspace.register_action(|workspace, _: &SignOut, window, cx| {
2430            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2431                copilot_ui::initiate_sign_out(copilot, window, cx);
2432            }
2433        });
2434    })
2435    .detach();
2436}