edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{
   5    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   6};
   7use cloud_llm_client::{
   8    EditPredictionRejectReason, EditPredictionRejection,
   9    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  10    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  11};
  12use collections::{HashMap, HashSet};
  13use copilot::{Copilot, Reinstall, SignIn, SignOut};
  14use db::kvp::{Dismissable, KEY_VALUE_STORE};
  15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::{
  18    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  19    channel::mpsc::{self, UnboundedReceiver},
  20    select_biased,
  21};
  22use gpui::BackgroundExecutor;
  23use gpui::http_client::Url;
  24use gpui::{
  25    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  26    http_client::{self, AsyncBody, Method},
  27    prelude::*,
  28};
  29use language::language_settings::all_language_settings;
  30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  31use language::{BufferSnapshot, OffsetRangeExt};
  32use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  33use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  34use release_channel::AppVersion;
  35use semver::Version;
  36use serde::de::DeserializeOwned;
  37use settings::{EditPredictionProvider, Settings as _, update_settings_file};
  38use std::collections::{VecDeque, hash_map};
  39use text::Edit;
  40use workspace::Workspace;
  41use zeta_prompt::ZetaPromptInput;
  42use zeta_prompt::ZetaVersion;
  43
  44use std::ops::Range;
  45use std::path::Path;
  46use std::rc::Rc;
  47use std::str::FromStr as _;
  48use std::sync::{Arc, LazyLock};
  49use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  50use std::{env, mem};
  51use thiserror::Error;
  52use util::{RangeExt as _, ResultExt as _};
  53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  54
  55pub mod cursor_excerpt;
  56pub mod example_spec;
  57mod license_detection;
  58pub mod mercury;
  59pub mod ollama;
  60mod onboarding_modal;
  61pub mod open_ai_response;
  62mod prediction;
  63pub mod sweep_ai;
  64
  65pub mod udiff;
  66
  67mod capture_example;
  68mod zed_edit_prediction_delegate;
  69pub mod zeta1;
  70pub mod zeta2;
  71
  72#[cfg(test)]
  73mod edit_prediction_tests;
  74
  75use crate::capture_example::{
  76    should_sample_edit_prediction_example_capture, should_send_testing_zeta2_request,
  77};
  78use crate::license_detection::LicenseDetectionWatcher;
  79use crate::mercury::Mercury;
  80use crate::ollama::Ollama;
  81use crate::onboarding_modal::ZedPredictModal;
  82pub use crate::prediction::EditPrediction;
  83pub use crate::prediction::EditPredictionId;
  84use crate::prediction::EditPredictionResult;
  85pub use crate::sweep_ai::SweepAi;
  86pub use capture_example::capture_example;
  87pub use language_model::ApiKeyState;
  88pub use telemetry_events::EditPredictionRating;
  89pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  90
// Actions registered in the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
 100
 101/// Maximum number of events to track.
 102const EVENT_COUNT_MAX: usize = 6;
 103const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 104const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
 105const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 106const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 107
 108static EDIT_PREDICTIONS_MODEL_ID: LazyLock<Option<String>> =
 109    LazyLock::new(|| env::var("ZED_ZETA_MODEL").ok());
 110
/// Feature flag named `"zeta2"`; enabled for staff.
/// NOTE(review): where the flag is consulted is outside this excerpt.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Staff get this flag regardless of other rollout settings (assumed FeatureFlag semantics).
    fn enabled_for_staff() -> bool {
        true
    }
}

/// Feature flag named `"edit-prediction-example-capture"`; enabled for staff.
/// NOTE(review): presumably gates the example-capture paths in `capture_example` — confirm.
pub struct EditPredictionExampleCaptureFeatureFlag;

impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
    const NAME: &'static str = "edit-prediction-example-capture";

    // Staff get this flag regardless of other rollout settings (assumed FeatureFlag semantics).
    fn enabled_for_staff() -> bool {
        true
    }
}
 130
/// App-global holder for the shared [`EditPredictionStore`] entity.
/// `EditPredictionStore::global` installs it and `EditPredictionStore::try_global` reads it.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 135
/// Central edit-prediction store. It holds:
/// - per-project state;
/// - the selected prediction model and the provider clients;
/// - the channel used to report rejected predictions.
pub struct EditPredictionStore {
    client: Arc<Client>,
    // Read for `edit_prediction_usage` in `usage`.
    user_store: Entity<UserStore>,
    // Refreshed whenever `RefreshLlmTokenListener` emits (subscription set up in `new`).
    llm_token: LlmApiToken,
    // Held only to keep the token-refresh subscription alive.
    _llm_token_subscription: Subscription,
    // Per-project state keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    // Initialised to `false` in `new`. NOTE(review): writers are outside this excerpt.
    update_required: bool,
    // Backend used for predictions; `new` starts with `Zeta2` at its default version.
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    pub ollama: Ollama,
    // Loaded via `load_data_collection_choice` in `new`.
    data_collection_choice: DataCollectionChoice,
    // Feeds the detached task in `new` that runs `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    // NOTE(review): populated outside this excerpt.
    shown_predictions: VecDeque<EditPrediction>,
    // NOTE(review): populated outside this excerpt.
    rated_predictions: HashSet<EditPredictionId>,
    // Parsed from `ZED_PREDICT_EDITS_URL` in `new`; `None` if unset or unparsable.
    custom_predict_edits_url: Option<Arc<Url>>,
}
 153
/// Backend used to produce edit predictions.
///
/// `Default` yields `Zeta1`, but `EditPredictionStore::new` selects `Zeta2` with its default
/// version.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    // Zeta2, parameterised by a `ZetaVersion`.
    Zeta2 {
        version: ZetaVersion,
    },
    Sweep,
    Mercury,
    Ollama,
}
 165
/// Inputs gathered for a single edit prediction request.
/// NOTE(review): populated outside this excerpt; field notes below are inferred from
/// names and types — confirm.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    // Presumably the cursor position the prediction is requested for.
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    // Optional sink for `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
 180
/// Debug notifications sent through a project's optional `debug_tx` channel.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Sent when a project's related-excerpt store starts refreshing.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // NOTE(review): the `StartedRefresh` handler in this file always sends an empty string.
    pub search_prompt: String,
}

/// Sent when a related-excerpt refresh finishes, with labelled metrics.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // `(label, value)` pairs, e.g. `"Cache Hits"`.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Debug event for the start of a prediction at `position` in `buffer` (sender not visible here).
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Debug event for the end of a prediction at `position` in `buffer` (sender not visible here).
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
 216
/// Maximum number of user actions retained per project (see `ProjectState::record_user_action`).
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One recorded user action in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    // Presumably milliseconds since the Unix epoch, per the name — confirm at the producer.
    pub timestamp_epoch_ms: u64,
}

/// Kind of user action captured in a [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 236
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    // The prompt-level event (e.g. a `BufferChange` diff).
    pub event: Arc<zeta_prompt::Event>,
    // Buffer snapshot from before the event was applied.
    pub old_snapshot: TextBufferSnapshot,
}
 243
/// Per-project edit prediction state owned by [`EditPredictionStore`].
struct ProjectState {
    // Finalized events; `events()` appends the finalized `last_event` after these.
    events: VecDeque<StoredEvent>,
    // In-progress event, finalized lazily when history is read.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    // Buffers registered for this project, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // At most two predictions are tracked as pending at once (fixed ArrayVec capacity).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // Optional sink for `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids inserted by `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    // Related-excerpt store created in `get_or_init_project`.
    context: Entity<RelatedExcerptStore>,
    // Consulted by `LastEvent::finalize` to flag open-source edits.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // Bounded to `USER_ACTION_HISTORY_SIZE` by `record_user_action`.
    user_actions: VecDeque<UserActionRecord>,
    // Project event subscription and release observer.
    _subscriptions: [gpui::Subscription; 2],
    // Created lazily by `start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
 261
 262impl ProjectState {
 263    fn record_user_action(&mut self, action: UserActionRecord) {
 264        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 265            self.user_actions.pop_front();
 266        }
 267        self.user_actions.push_back(action);
 268    }
 269
 270    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 271        self.events
 272            .iter()
 273            .cloned()
 274            .chain(
 275                self.last_event
 276                    .as_ref()
 277                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 278            )
 279            .collect()
 280    }
 281
 282    pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
 283        self.events
 284            .iter()
 285            .cloned()
 286            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 287                let (one, two) = event.split_by_pause();
 288                let one = one.finalize(&self.license_detection_watchers, cx);
 289                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 290                one.into_iter().chain(two)
 291            }))
 292            .collect()
 293    }
 294
 295    fn cancel_pending_prediction(
 296        &mut self,
 297        pending_prediction: PendingPrediction,
 298        cx: &mut Context<EditPredictionStore>,
 299    ) {
 300        self.cancelled_predictions.insert(pending_prediction.id);
 301
 302        if pending_prediction.drop_on_cancel {
 303            drop(pending_prediction.task);
 304        } else {
 305            cx.spawn(async move |this, cx| {
 306                let Some(prediction_id) = pending_prediction.task.await else {
 307                    return;
 308                };
 309
 310                this.update(cx, |this, _cx| {
 311                    this.reject_prediction(
 312                        prediction_id,
 313                        EditPredictionRejectReason::Canceled,
 314                        false,
 315                    );
 316                })
 317                .ok();
 318            })
 319            .detach()
 320        }
 321    }
 322
 323    fn active_buffer(
 324        &self,
 325        project: &Entity<Project>,
 326        cx: &App,
 327    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 328        let project = project.read(cx);
 329        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 330        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 331        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 332        Some((active_buffer, registered_buffer.last_position))
 333    }
 334}
 335
 336#[derive(Debug, Clone)]
 337struct CurrentEditPrediction {
 338    pub requested_by: PredictionRequestedBy,
 339    pub prediction: EditPrediction,
 340    pub was_shown: bool,
 341    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
 342}
 343
 344impl CurrentEditPrediction {
 345    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 346        let Some(new_edits) = self
 347            .prediction
 348            .interpolate(&self.prediction.buffer.read(cx))
 349        else {
 350            return false;
 351        };
 352
 353        if self.prediction.buffer != old_prediction.prediction.buffer {
 354            return true;
 355        }
 356
 357        let Some(old_edits) = old_prediction
 358            .prediction
 359            .interpolate(&old_prediction.prediction.buffer.read(cx))
 360        else {
 361            return true;
 362        };
 363
 364        let requested_by_buffer_id = self.requested_by.buffer_id();
 365
 366        // This reduces the occurrence of UI thrash from replacing edits
 367        //
 368        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 369        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 370            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 371            && old_edits.len() == 1
 372            && new_edits.len() == 1
 373        {
 374            let (old_range, old_text) = &old_edits[0];
 375            let (new_range, new_text) = &new_edits[0];
 376            new_range == old_range && new_text.starts_with(old_text.as_ref())
 377        } else {
 378            true
 379        }
 380    }
 381}
 382
 383#[derive(Debug, Clone)]
 384enum PredictionRequestedBy {
 385    DiagnosticsUpdate,
 386    Buffer(EntityId),
 387}
 388
 389impl PredictionRequestedBy {
 390    pub fn buffer_id(&self) -> Option<EntityId> {
 391        match self {
 392            PredictionRequestedBy::DiagnosticsUpdate => None,
 393            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 394        }
 395    }
 396}
 397
/// An in-flight prediction request tracked by `ProjectState`.
#[derive(Debug)]
struct PendingPrediction {
    // Identifier recorded in `cancelled_predictions` on cancel.
    id: usize,
    // Resolves to the prediction id when a prediction was produced; used to report rejection.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 406
 407/// A prediction from the perspective of a buffer.
 408#[derive(Debug)]
 409enum BufferEditPrediction<'a> {
 410    Local { prediction: &'a EditPrediction },
 411    Jump { prediction: &'a EditPrediction },
 412}
 413
 414#[cfg(test)]
 415impl std::ops::Deref for BufferEditPrediction<'_> {
 416    type Target = EditPrediction;
 417
 418    fn deref(&self) -> &Self::Target {
 419        match self {
 420            BufferEditPrediction::Local { prediction } => prediction,
 421            BufferEditPrediction::Jump { prediction } => prediction,
 422        }
 423    }
 424}
 425
/// Per-buffer tracking data for a buffer registered with a project.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    // Returned alongside the buffer by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    // NOTE(review): created in `register_buffer_impl` (outside this excerpt).
    _subscriptions: [gpui::Subscription; 2],
}
 432
/// The in-progress (not yet finalized) edit event for a project.
#[derive(Clone)]
struct LastEvent {
    // Buffer state at the start of the event.
    old_snapshot: TextBufferSnapshot,
    // Buffer state at the latest edit.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    // Boundary used by `split_by_pause` to divide the event in two.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 443
 444impl LastEvent {
 445    pub fn finalize(
 446        &self,
 447        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 448        cx: &App,
 449    ) -> Option<StoredEvent> {
 450        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 451        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 452
 453        let in_open_source_repo =
 454            [self.new_file.as_ref(), self.old_file.as_ref()]
 455                .iter()
 456                .all(|file| {
 457                    file.is_some_and(|file| {
 458                        license_detection_watchers
 459                            .get(&file.worktree_id(cx))
 460                            .is_some_and(|watcher| watcher.is_project_open_source())
 461                    })
 462                });
 463
 464        let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 465
 466        if path == old_path && diff.is_empty() {
 467            None
 468        } else {
 469            Some(StoredEvent {
 470                event: Arc::new(zeta_prompt::Event::BufferChange {
 471                    old_path,
 472                    path,
 473                    diff,
 474                    in_open_source_repo,
 475                    // TODO: Actually detect if this edit was predicted or not
 476                    predicted: false,
 477                }),
 478                old_snapshot: self.old_snapshot.clone(),
 479            })
 480        }
 481    }
 482
 483    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 484        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 485            return (self.clone(), None);
 486        };
 487
 488        let before = LastEvent {
 489            old_snapshot: self.old_snapshot.clone(),
 490            new_snapshot: boundary_snapshot.clone(),
 491            old_file: self.old_file.clone(),
 492            new_file: self.new_file.clone(),
 493            edit_range: None,
 494            snapshot_after_last_editing_pause: None,
 495            last_edit_time: self.last_edit_time,
 496        };
 497
 498        let after = LastEvent {
 499            old_snapshot: boundary_snapshot.clone(),
 500            new_snapshot: self.new_snapshot.clone(),
 501            old_file: self.old_file.clone(),
 502            new_file: self.new_file.clone(),
 503            edit_range: None,
 504            snapshot_after_last_editing_pause: None,
 505            last_edit_time: self.last_edit_time,
 506        };
 507
 508        (before, Some(after))
 509    }
 510}
 511
 512pub(crate) fn compute_diff_between_snapshots(
 513    old_snapshot: &TextBufferSnapshot,
 514    new_snapshot: &TextBufferSnapshot,
 515) -> Option<String> {
 516    let edits: Vec<Edit<usize>> = new_snapshot
 517        .edits_since::<usize>(&old_snapshot.version)
 518        .collect();
 519
 520    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 521
 522    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 523    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 524    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 525    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 526
 527    const CONTEXT_LINES: u32 = 3;
 528
 529    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 530    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 531    let old_context_end_row =
 532        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 533    let new_context_end_row =
 534        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 535
 536    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 537    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 538    let old_end_line_offset = old_snapshot
 539        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 540    let new_end_line_offset = new_snapshot
 541        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 542    let old_edit_range = old_start_line_offset..old_end_line_offset;
 543    let new_edit_range = new_start_line_offset..new_end_line_offset;
 544
 545    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 546    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 547
 548    let diff = language::unified_diff_with_offsets(
 549        &old_region_text,
 550        &new_region_text,
 551        old_context_start_row,
 552        new_context_start_row,
 553    );
 554
 555    Some(diff)
 556}
 557
 558fn buffer_path_with_id_fallback(
 559    file: Option<&Arc<dyn File>>,
 560    snapshot: &TextBufferSnapshot,
 561    cx: &App,
 562) -> Arc<Path> {
 563    if let Some(file) = file {
 564        file.full_path(cx).into()
 565    } else {
 566        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 567    }
 568}
 569
 570impl EditPredictionStore {
 571    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 572        cx.try_global::<EditPredictionStoreGlobal>()
 573            .map(|global| global.0.clone())
 574    }
 575
 576    pub fn global(
 577        client: &Arc<Client>,
 578        user_store: &Entity<UserStore>,
 579        cx: &mut App,
 580    ) -> Entity<Self> {
 581        cx.try_global::<EditPredictionStoreGlobal>()
 582            .map(|global| global.0.clone())
 583            .unwrap_or_else(|| {
 584                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 585                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 586                ep_store
 587            })
 588    }
 589
 590    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 591        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 592        let data_collection_choice = Self::load_data_collection_choice();
 593
 594        let llm_token = LlmApiToken::default();
 595
 596        let (reject_tx, reject_rx) = mpsc::unbounded();
 597        cx.background_spawn({
 598            let client = client.clone();
 599            let llm_token = llm_token.clone();
 600            let app_version = AppVersion::global(cx);
 601            let background_executor = cx.background_executor().clone();
 602            async move {
 603                Self::handle_rejected_predictions(
 604                    reject_rx,
 605                    client,
 606                    llm_token,
 607                    app_version,
 608                    background_executor,
 609                )
 610                .await
 611            }
 612        })
 613        .detach();
 614
 615        let this = Self {
 616            projects: HashMap::default(),
 617            client,
 618            user_store,
 619            llm_token,
 620            _llm_token_subscription: cx.subscribe(
 621                &refresh_llm_token_listener,
 622                |this, _listener, _event, cx| {
 623                    let client = this.client.clone();
 624                    let llm_token = this.llm_token.clone();
 625                    cx.spawn(async move |_this, _cx| {
 626                        llm_token.refresh(&client).await?;
 627                        anyhow::Ok(())
 628                    })
 629                    .detach_and_log_err(cx);
 630                },
 631            ),
 632            update_required: false,
 633            edit_prediction_model: EditPredictionModel::Zeta2 {
 634                version: Default::default(),
 635            },
 636            sweep_ai: SweepAi::new(cx),
 637            mercury: Mercury::new(cx),
 638            ollama: Ollama::new(),
 639
 640            data_collection_choice,
 641            reject_predictions_tx: reject_tx,
 642            rated_predictions: Default::default(),
 643            shown_predictions: Default::default(),
 644            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
 645                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
 646                Err(_) => None,
 647            },
 648        };
 649
 650        this
 651    }
 652
 653    #[cfg(test)]
 654    pub fn set_custom_predict_edits_url(&mut self, url: Url) {
 655        self.custom_predict_edits_url = Some(url.into());
 656    }
 657
 658    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 659        self.edit_prediction_model = model;
 660    }
 661
 662    pub fn icons(&self) -> edit_prediction_types::EditPredictionIconSet {
 663        use ui::IconName;
 664        match self.edit_prediction_model {
 665            EditPredictionModel::Sweep => {
 666                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 667                    .with_disabled(IconName::SweepAiDisabled)
 668                    .with_up(IconName::SweepAiUp)
 669                    .with_down(IconName::SweepAiDown)
 670                    .with_error(IconName::SweepAiError)
 671            }
 672            EditPredictionModel::Mercury => {
 673                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 674            }
 675            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
 676                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 677                    .with_disabled(IconName::ZedPredictDisabled)
 678                    .with_up(IconName::ZedPredictUp)
 679                    .with_down(IconName::ZedPredictDown)
 680                    .with_error(IconName::ZedPredictError)
 681            }
 682            EditPredictionModel::Ollama => {
 683                edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 684            }
 685        }
 686    }
 687
 688    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 689        self.sweep_ai.api_token.read(cx).has_key()
 690    }
 691
 692    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 693        self.mercury.api_token.read(cx).has_key()
 694    }
 695
 696    pub fn clear_history(&mut self) {
 697        for project_state in self.projects.values_mut() {
 698            project_state.events.clear();
 699            project_state.last_event.take();
 700        }
 701    }
 702
 703    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 704        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 705            project_state.events.clear();
 706            project_state.last_event.take();
 707        }
 708    }
 709
 710    pub fn edit_history_for_project(
 711        &self,
 712        project: &Entity<Project>,
 713        cx: &App,
 714    ) -> Vec<StoredEvent> {
 715        self.projects
 716            .get(&project.entity_id())
 717            .map(|project_state| project_state.events(cx))
 718            .unwrap_or_default()
 719    }
 720
 721    pub fn edit_history_for_project_with_pause_split_last_event(
 722        &self,
 723        project: &Entity<Project>,
 724        cx: &App,
 725    ) -> Vec<StoredEvent> {
 726        self.projects
 727            .get(&project.entity_id())
 728            .map(|project_state| project_state.events_split_by_pause(cx))
 729            .unwrap_or_default()
 730    }
 731
 732    pub fn context_for_project<'a>(
 733        &'a self,
 734        project: &Entity<Project>,
 735        cx: &'a mut App,
 736    ) -> Vec<RelatedFile> {
 737        self.projects
 738            .get(&project.entity_id())
 739            .map(|project| {
 740                project
 741                    .context
 742                    .update(cx, |context, cx| context.related_files(cx))
 743            })
 744            .unwrap_or_default()
 745    }
 746
 747    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 748        self.projects
 749            .get(&project.entity_id())
 750            .and_then(|project| project.copilot.clone())
 751    }
 752
 753    pub fn start_copilot_for_project(
 754        &mut self,
 755        project: &Entity<Project>,
 756        cx: &mut Context<Self>,
 757    ) -> Option<Entity<Copilot>> {
 758        if DisableAiSettings::get(None, cx).disable_ai {
 759            return None;
 760        }
 761        let state = self.get_or_init_project(project, cx);
 762
 763        if state.copilot.is_some() {
 764            return state.copilot.clone();
 765        }
 766        let _project = project.clone();
 767        let project = project.read(cx);
 768
 769        let node = project.node_runtime().cloned();
 770        if let Some(node) = node {
 771            let next_id = project.languages().next_language_server_id();
 772            let fs = project.fs().clone();
 773
 774            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 775            state.copilot = Some(copilot.clone());
 776            Some(copilot)
 777        } else {
 778            None
 779        }
 780    }
 781
 782    pub fn context_for_project_with_buffers<'a>(
 783        &'a self,
 784        project: &Entity<Project>,
 785        cx: &'a mut App,
 786    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 787        self.projects
 788            .get(&project.entity_id())
 789            .map(|project| {
 790                project
 791                    .context
 792                    .update(cx, |context, cx| context.related_files_with_buffers(cx))
 793            })
 794            .unwrap_or_default()
 795    }
 796
 797    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 798        if matches!(
 799            self.edit_prediction_model,
 800            EditPredictionModel::Zeta2 { .. }
 801        ) {
 802            self.user_store.read(cx).edit_prediction_usage()
 803        } else {
 804            None
 805        }
 806    }
 807
 808    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 809        self.get_or_init_project(project, cx);
 810    }
 811
 812    pub fn register_buffer(
 813        &mut self,
 814        buffer: &Entity<Buffer>,
 815        project: &Entity<Project>,
 816        cx: &mut Context<Self>,
 817    ) {
 818        let project_state = self.get_or_init_project(project, cx);
 819        Self::register_buffer_impl(project_state, buffer, project, cx);
 820    }
 821
 822    fn get_or_init_project(
 823        &mut self,
 824        project: &Entity<Project>,
 825        cx: &mut Context<Self>,
 826    ) -> &mut ProjectState {
 827        let entity_id = project.entity_id();
 828        self.projects
 829            .entry(entity_id)
 830            .or_insert_with(|| ProjectState {
 831                context: {
 832                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 833                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 834                        this.handle_excerpt_store_event(entity_id, event);
 835                    })
 836                    .detach();
 837                    related_excerpt_store
 838                },
 839                events: VecDeque::new(),
 840                last_event: None,
 841                recent_paths: VecDeque::new(),
 842                debug_tx: None,
 843                registered_buffers: HashMap::default(),
 844                current_prediction: None,
 845                cancelled_predictions: HashSet::default(),
 846                pending_predictions: ArrayVec::new(),
 847                next_pending_prediction_id: 0,
 848                last_prediction_refresh: None,
 849                license_detection_watchers: HashMap::default(),
 850                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
 851                _subscriptions: [
 852                    cx.subscribe(&project, Self::handle_project_event),
 853                    cx.observe_release(&project, move |this, _, cx| {
 854                        this.projects.remove(&entity_id);
 855                        cx.notify();
 856                    }),
 857                ],
 858                copilot: None,
 859            })
 860    }
 861
 862    pub fn remove_project(&mut self, project: &Entity<Project>) {
 863        self.projects.remove(&project.entity_id());
 864    }
 865
 866    fn handle_excerpt_store_event(
 867        &mut self,
 868        project_entity_id: EntityId,
 869        event: &RelatedExcerptStoreEvent,
 870    ) {
 871        if let Some(project_state) = self.projects.get(&project_entity_id) {
 872            if let Some(debug_tx) = project_state.debug_tx.clone() {
 873                match event {
 874                    RelatedExcerptStoreEvent::StartedRefresh => {
 875                        debug_tx
 876                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 877                                ContextRetrievalStartedDebugEvent {
 878                                    project_entity_id: project_entity_id,
 879                                    timestamp: Instant::now(),
 880                                    search_prompt: String::new(),
 881                                },
 882                            ))
 883                            .ok();
 884                    }
 885                    RelatedExcerptStoreEvent::FinishedRefresh {
 886                        cache_hit_count,
 887                        cache_miss_count,
 888                        mean_definition_latency,
 889                        max_definition_latency,
 890                    } => {
 891                        debug_tx
 892                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 893                                ContextRetrievalFinishedDebugEvent {
 894                                    project_entity_id: project_entity_id,
 895                                    timestamp: Instant::now(),
 896                                    metadata: vec![
 897                                        (
 898                                            "Cache Hits",
 899                                            format!(
 900                                                "{}/{}",
 901                                                cache_hit_count,
 902                                                cache_hit_count + cache_miss_count
 903                                            )
 904                                            .into(),
 905                                        ),
 906                                        (
 907                                            "Max LSP Time",
 908                                            format!("{} ms", max_definition_latency.as_millis())
 909                                                .into(),
 910                                        ),
 911                                        (
 912                                            "Mean LSP Time",
 913                                            format!("{} ms", mean_definition_latency.as_millis())
 914                                                .into(),
 915                                        ),
 916                                    ],
 917                                },
 918                            ))
 919                            .ok();
 920                    }
 921                }
 922            }
 923        }
 924    }
 925
 926    pub fn debug_info(
 927        &mut self,
 928        project: &Entity<Project>,
 929        cx: &mut Context<Self>,
 930    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 931        let project_state = self.get_or_init_project(project, cx);
 932        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 933        project_state.debug_tx = Some(debug_watch_tx);
 934        debug_watch_rx
 935    }
 936
 937    fn handle_project_event(
 938        &mut self,
 939        project: Entity<Project>,
 940        event: &project::Event,
 941        cx: &mut Context<Self>,
 942    ) {
 943        // TODO [zeta2] init with recent paths
 944        match event {
 945            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 946                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 947                    return;
 948                };
 949                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 950                if let Some(path) = path {
 951                    if let Some(ix) = project_state
 952                        .recent_paths
 953                        .iter()
 954                        .position(|probe| probe == &path)
 955                    {
 956                        project_state.recent_paths.remove(ix);
 957                    }
 958                    project_state.recent_paths.push_front(path);
 959                }
 960            }
 961            project::Event::DiagnosticsUpdated { .. } => {
 962                if cx.has_flag::<Zeta2FeatureFlag>() {
 963                    self.refresh_prediction_from_diagnostics(project, cx);
 964                }
 965            }
 966            _ => (),
 967        }
 968    }
 969
    /// Ensures `buffer` is tracked in `project_state` and returns its
    /// [`RegisteredBuffer`] entry.
    ///
    /// Bookkeeping done the first time something is seen:
    /// - If the buffer has a file whose worktree is found in `project`, and that
    ///   worktree has no license detection watcher yet, a `LicenseDetectionWatcher` is
    ///   created for it. The watcher entry is removed when the worktree entity is
    ///   released.
    /// - If the buffer is not registered yet, its current text snapshot and file are
    ///   stored. `BufferEvent::Edited` events are forwarded to
    ///   `report_changes_for_buffer` while the project can still be upgraded. The
    ///   registration is removed when the buffer entity is released.
    ///
    /// This is an associated function taking `project_state` directly, so callers can
    /// pass the state they already obtained from `get_or_init_project`.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Make sure the worktree containing this buffer's file has a license
        // detection watcher.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher once the worktree entity is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        // Register the buffer only once; later calls return the existing entry.
        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Capture the project weakly; edits are only reported
                            // while it can still be upgraded.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Forget the registration when the buffer entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1035
    /// Records the edits made to `buffer` since it was last observed and folds them
    /// into the project's edit history.
    ///
    /// 1. Ensures the project and buffer are registered. Returns early if the
    ///    buffer's version equals the stored snapshot's version.
    /// 2. Replaces the stored file and snapshot with the current ones. Summarizes the
    ///    edits since the old snapshot into one `UserActionRecord`, located at the end
    ///    offset of the last edit.
    /// 3. Coalesces the change into the pending `last_event` when two conditions
    ///    hold:
    ///    - it directly continues that event (same buffer, and the event's new
    ///      version is this change's old version), and
    ///    - the edited rows are within `CHANGE_GROUPING_LINE_SPAN` of each other.
    ///
    ///    Otherwise the pending `last_event` is finalized into `events`, and a new
    ///    `last_event` is started for this change.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the last observed snapshot.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        // Spans from the first edit's start to the last edit's end (in iteration
        // order).
        let mut edit_range: Option<Range<Anchor>> = None;
        // End offset, in the new snapshot, of the last edit.
        let mut last_offset: Option<usize> = None;

        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        if num_edits > 0 {
            // Classification heuristic:
            // - Nothing inserted and something deleted => Delete*.
            // - Anything else => Insert*.
            // It is "Char" when the total inserted (or, for pure deletions, deleted)
            // length equals the number of edits, and "Selection" otherwise.
            // NOTE(review): lengths come from `usize` offsets, presumably bytes, so a
            // single multi-byte character would count as a selection. Confirm this is
            // acceptable.
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                // Falls back to 0 if the system clock is before the UNIX epoch.
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // This change starts exactly where the pending event left off.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // ...and its edited rows are close to (or overlap) the pending event's.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && edit_range
                    .as_ref()
                    .zip(last_event.edit_range.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        if a.start > b.end {
                            a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else if b.start > a.end {
                            b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else {
                            true
                        }
                    });

            if should_coalesce {
                // If at least LAST_CHANGE_GROUPING_TIME passed since the previous
                // edit, remember the snapshot from before this change as the state
                // after the last editing pause.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = edit_range;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalesced: finalize the pending event into the history.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Evict the oldest event first. NOTE(review): this caps `events` at
                // EVENT_COUNT_MAX - 1, presumably reserving a slot for the new
                // `last_event` set below. Confirm.
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1159
1160    fn prediction_at(
1161        &mut self,
1162        buffer: &Entity<Buffer>,
1163        position: Option<language::Anchor>,
1164        project: &Entity<Project>,
1165        cx: &App,
1166    ) -> Option<BufferEditPrediction<'_>> {
1167        let project_state = self.projects.get_mut(&project.entity_id())?;
1168        if let Some(position) = position
1169            && let Some(buffer) = project_state
1170                .registered_buffers
1171                .get_mut(&buffer.entity_id())
1172        {
1173            buffer.last_position = Some(position);
1174        }
1175
1176        let CurrentEditPrediction {
1177            requested_by,
1178            prediction,
1179            ..
1180        } = project_state.current_prediction.as_ref()?;
1181
1182        if prediction.targets_buffer(buffer.read(cx)) {
1183            Some(BufferEditPrediction::Local { prediction })
1184        } else {
1185            let show_jump = match requested_by {
1186                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1187                    requested_by_buffer_id == &buffer.entity_id()
1188                }
1189                PredictionRequestedBy::DiagnosticsUpdate => true,
1190            };
1191
1192            if show_jump {
1193                Some(BufferEditPrediction::Jump { prediction })
1194            } else {
1195                None
1196            }
1197        }
1198    }
1199
1200    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1201        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1202            return;
1203        };
1204
1205        let Some(current_prediction) = project_state.current_prediction.take() else {
1206            return;
1207        };
1208
1209        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1210            project_state.cancel_pending_prediction(pending_prediction, cx);
1211        }
1212
1213        match self.edit_prediction_model {
1214            EditPredictionModel::Sweep => {
1215                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1216            }
1217            EditPredictionModel::Mercury | EditPredictionModel::Ollama => {}
1218            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1219                zeta2::edit_prediction_accepted(self, current_prediction, cx)
1220            }
1221        }
1222    }
1223
1224    async fn handle_rejected_predictions(
1225        rx: UnboundedReceiver<EditPredictionRejection>,
1226        client: Arc<Client>,
1227        llm_token: LlmApiToken,
1228        app_version: Version,
1229        background_executor: BackgroundExecutor,
1230    ) {
1231        let mut rx = std::pin::pin!(rx.peekable());
1232        let mut batched = Vec::new();
1233
1234        while let Some(rejection) = rx.next().await {
1235            batched.push(rejection);
1236
1237            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1238                select_biased! {
1239                    next = rx.as_mut().peek().fuse() => {
1240                        if next.is_some() {
1241                            continue;
1242                        }
1243                    }
1244                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1245                }
1246            }
1247
1248            let url = client
1249                .http_client()
1250                .build_zed_llm_url("/predict_edits/reject", &[])
1251                .unwrap();
1252
1253            let flush_count = batched
1254                .len()
1255                // in case items have accumulated after failure
1256                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1257            let start = batched.len() - flush_count;
1258
1259            let body = RejectEditPredictionsBodyRef {
1260                rejections: &batched[start..],
1261            };
1262
1263            let result = Self::send_api_request::<()>(
1264                |builder| {
1265                    let req = builder
1266                        .uri(url.as_ref())
1267                        .body(serde_json::to_string(&body)?.into());
1268                    anyhow::Ok(req?)
1269                },
1270                client.clone(),
1271                llm_token.clone(),
1272                app_version.clone(),
1273                true,
1274            )
1275            .await;
1276
1277            if result.log_err().is_some() {
1278                batched.drain(start..);
1279            }
1280        }
1281    }
1282
1283    fn reject_current_prediction(
1284        &mut self,
1285        reason: EditPredictionRejectReason,
1286        project: &Entity<Project>,
1287    ) {
1288        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1289            project_state.pending_predictions.clear();
1290            if let Some(prediction) = project_state.current_prediction.take() {
1291                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1292            }
1293        };
1294    }
1295
1296    fn did_show_current_prediction(
1297        &mut self,
1298        project: &Entity<Project>,
1299        display_type: edit_prediction_types::SuggestionDisplayType,
1300        cx: &mut Context<Self>,
1301    ) {
1302        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1303            return;
1304        };
1305
1306        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1307            return;
1308        };
1309
1310        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1311        let previous_shown_with = current_prediction.shown_with;
1312
1313        if previous_shown_with.is_none() || !is_jump {
1314            current_prediction.shown_with = Some(display_type);
1315        }
1316
1317        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1318
1319        if is_first_non_jump_show {
1320            current_prediction.was_shown = true;
1321        }
1322
1323        let display_type_changed = previous_shown_with != Some(display_type);
1324
1325        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1326            sweep_ai::edit_prediction_shown(
1327                &self.sweep_ai,
1328                self.client.clone(),
1329                &current_prediction.prediction,
1330                display_type,
1331                cx,
1332            );
1333        }
1334
1335        if is_first_non_jump_show {
1336            self.shown_predictions
1337                .push_front(current_prediction.prediction.clone());
1338            if self.shown_predictions.len() > 50 {
1339                let completion = self.shown_predictions.pop_back().unwrap();
1340                self.rated_predictions.remove(&completion.id);
1341            }
1342        }
1343    }
1344
1345    fn reject_prediction(
1346        &mut self,
1347        prediction_id: EditPredictionId,
1348        reason: EditPredictionRejectReason,
1349        was_shown: bool,
1350    ) {
1351        match self.edit_prediction_model {
1352            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1353                if self.custom_predict_edits_url.is_some() {
1354                    return;
1355                }
1356            }
1357            EditPredictionModel::Sweep
1358            | EditPredictionModel::Mercury
1359            | EditPredictionModel::Ollama => return,
1360        }
1361
1362        self.reject_predictions_tx
1363            .unbounded_send(EditPredictionRejection {
1364                request_id: prediction_id.to_string(),
1365                reason,
1366                was_shown,
1367            })
1368            .log_err();
1369    }
1370
1371    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1372        self.projects
1373            .get(&project.entity_id())
1374            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1375    }
1376
1377    pub fn refresh_prediction_from_buffer(
1378        &mut self,
1379        project: Entity<Project>,
1380        buffer: Entity<Buffer>,
1381        position: language::Anchor,
1382        cx: &mut Context<Self>,
1383    ) {
1384        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1385            let Some(request_task) = this
1386                .update(cx, |this, cx| {
1387                    this.request_prediction(
1388                        &project,
1389                        &buffer,
1390                        position,
1391                        PredictEditsRequestTrigger::Other,
1392                        cx,
1393                    )
1394                })
1395                .log_err()
1396            else {
1397                return Task::ready(anyhow::Ok(None));
1398            };
1399
1400            cx.spawn(async move |_cx| {
1401                request_task.await.map(|prediction_result| {
1402                    prediction_result.map(|prediction_result| {
1403                        (
1404                            prediction_result,
1405                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1406                        )
1407                    })
1408                })
1409            })
1410        })
1411    }
1412
1413    pub fn refresh_prediction_from_diagnostics(
1414        &mut self,
1415        project: Entity<Project>,
1416        cx: &mut Context<Self>,
1417    ) {
1418        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1419            return;
1420        };
1421
1422        // Prefer predictions from buffer
1423        if project_state.current_prediction.is_some() {
1424            return;
1425        };
1426
1427        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1428            let Some((active_buffer, snapshot, cursor_point)) = this
1429                .read_with(cx, |this, cx| {
1430                    let project_state = this.projects.get(&project.entity_id())?;
1431                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
1432                    let snapshot = buffer.read(cx).snapshot();
1433
1434                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
1435                        return None;
1436                    }
1437
1438                    let cursor_point = position
1439                        .map(|pos| pos.to_point(&snapshot))
1440                        .unwrap_or_default();
1441
1442                    Some((buffer, snapshot, cursor_point))
1443                })
1444                .log_err()
1445                .flatten()
1446            else {
1447                return Task::ready(anyhow::Ok(None));
1448            };
1449
1450            cx.spawn(async move |cx| {
1451                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1452                    active_buffer,
1453                    &snapshot,
1454                    Default::default(),
1455                    cursor_point,
1456                    &project,
1457                    cx,
1458                )
1459                .await?
1460                else {
1461                    return anyhow::Ok(None);
1462                };
1463
1464                let Some(prediction_result) = this
1465                    .update(cx, |this, cx| {
1466                        this.request_prediction(
1467                            &project,
1468                            &jump_buffer,
1469                            jump_position,
1470                            PredictEditsRequestTrigger::Diagnostics,
1471                            cx,
1472                        )
1473                    })?
1474                    .await?
1475                else {
1476                    return anyhow::Ok(None);
1477                };
1478
1479                this.update(cx, |this, cx| {
1480                    Some((
1481                        if this
1482                            .get_or_init_project(&project, cx)
1483                            .current_prediction
1484                            .is_none()
1485                        {
1486                            prediction_result
1487                        } else {
1488                            EditPredictionResult {
1489                                id: prediction_result.id,
1490                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1491                            }
1492                        },
1493                        PredictionRequestedBy::DiagnosticsUpdate,
1494                    ))
1495                })
1496            })
1497        });
1498    }
1499
1500    fn predictions_enabled_at(
1501        snapshot: &BufferSnapshot,
1502        position: Option<language::Anchor>,
1503        cx: &App,
1504    ) -> bool {
1505        let file = snapshot.file();
1506        let all_settings = all_language_settings(file, cx);
1507        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1508            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1509        {
1510            return false;
1511        }
1512
1513        if let Some(last_position) = position {
1514            let settings = snapshot.settings_at(last_position, cx);
1515
1516            if !settings.edit_predictions_disabled_in.is_empty()
1517                && let Some(scope) = snapshot.language_scope_at(last_position)
1518                && let Some(scope_name) = scope.override_name()
1519                && settings
1520                    .edit_predictions_disabled_in
1521                    .iter()
1522                    .any(|s| s == scope_name)
1523            {
1524                return false;
1525            }
1526        }
1527
1528        true
1529    }
1530
    /// Minimum delay between two consecutive prediction refreshes that share the same
    /// throttle entity (see `queue_prediction_refresh`).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Tests use a zero timeout so that refreshes are effectively not throttled.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1535
1536    fn queue_prediction_refresh(
1537        &mut self,
1538        project: Entity<Project>,
1539        throttle_entity: EntityId,
1540        cx: &mut Context<Self>,
1541        do_refresh: impl FnOnce(
1542            WeakEntity<Self>,
1543            &mut AsyncApp,
1544        )
1545            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1546        + 'static,
1547    ) {
1548        let is_ollama = self.edit_prediction_model == EditPredictionModel::Ollama;
1549        let drop_on_cancel = is_ollama;
1550        let max_pending_predictions = if is_ollama { 1 } else { 2 };
1551        let project_state = self.get_or_init_project(&project, cx);
1552        let pending_prediction_id = project_state.next_pending_prediction_id;
1553        project_state.next_pending_prediction_id += 1;
1554        let last_request = project_state.last_prediction_refresh;
1555
1556        let task = cx.spawn(async move |this, cx| {
1557            if let Some((last_entity, last_timestamp)) = last_request
1558                && throttle_entity == last_entity
1559                && let Some(timeout) =
1560                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1561            {
1562                cx.background_executor().timer(timeout).await;
1563            }
1564
1565            // If this task was cancelled before the throttle timeout expired,
1566            // do not perform a request.
1567            let mut is_cancelled = true;
1568            this.update(cx, |this, cx| {
1569                let project_state = this.get_or_init_project(&project, cx);
1570                if !project_state
1571                    .cancelled_predictions
1572                    .remove(&pending_prediction_id)
1573                {
1574                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1575                    is_cancelled = false;
1576                }
1577            })
1578            .ok();
1579            if is_cancelled {
1580                return None;
1581            }
1582
1583            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1584            let new_prediction_id = new_prediction_result
1585                .as_ref()
1586                .map(|(prediction, _)| prediction.id.clone());
1587
1588            // When a prediction completes, remove it from the pending list, and cancel
1589            // any pending predictions that were enqueued before it.
1590            this.update(cx, |this, cx| {
1591                let project_state = this.get_or_init_project(&project, cx);
1592
1593                let is_cancelled = project_state
1594                    .cancelled_predictions
1595                    .remove(&pending_prediction_id);
1596
1597                let new_current_prediction = if !is_cancelled
1598                    && let Some((prediction_result, requested_by)) = new_prediction_result
1599                {
1600                    match prediction_result.prediction {
1601                        Ok(prediction) => {
1602                            let new_prediction = CurrentEditPrediction {
1603                                requested_by,
1604                                prediction,
1605                                was_shown: false,
1606                                shown_with: None,
1607                            };
1608
1609                            if let Some(current_prediction) =
1610                                project_state.current_prediction.as_ref()
1611                            {
1612                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1613                                {
1614                                    this.reject_current_prediction(
1615                                        EditPredictionRejectReason::Replaced,
1616                                        &project,
1617                                    );
1618
1619                                    Some(new_prediction)
1620                                } else {
1621                                    this.reject_prediction(
1622                                        new_prediction.prediction.id,
1623                                        EditPredictionRejectReason::CurrentPreferred,
1624                                        false,
1625                                    );
1626                                    None
1627                                }
1628                            } else {
1629                                Some(new_prediction)
1630                            }
1631                        }
1632                        Err(reject_reason) => {
1633                            this.reject_prediction(prediction_result.id, reject_reason, false);
1634                            None
1635                        }
1636                    }
1637                } else {
1638                    None
1639                };
1640
1641                let project_state = this.get_or_init_project(&project, cx);
1642
1643                if let Some(new_prediction) = new_current_prediction {
1644                    project_state.current_prediction = Some(new_prediction);
1645                }
1646
1647                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1648                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1649                    if pending_prediction.id == pending_prediction_id {
1650                        pending_predictions.remove(ix);
1651                        for pending_prediction in pending_predictions.drain(0..ix) {
1652                            project_state.cancel_pending_prediction(pending_prediction, cx)
1653                        }
1654                        break;
1655                    }
1656                }
1657                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1658                cx.notify();
1659            })
1660            .ok();
1661
1662            new_prediction_id
1663        });
1664
1665        if project_state.pending_predictions.len() < max_pending_predictions {
1666            project_state.pending_predictions.push(PendingPrediction {
1667                id: pending_prediction_id,
1668                task,
1669                drop_on_cancel,
1670            });
1671        } else {
1672            let pending_prediction = project_state.pending_predictions.pop().unwrap();
1673            project_state.pending_predictions.push(PendingPrediction {
1674                id: pending_prediction_id,
1675                task,
1676                drop_on_cancel,
1677            });
1678            project_state.cancel_pending_prediction(pending_prediction, cx);
1679        }
1680    }
1681
1682    pub fn request_prediction(
1683        &mut self,
1684        project: &Entity<Project>,
1685        active_buffer: &Entity<Buffer>,
1686        position: language::Anchor,
1687        trigger: PredictEditsRequestTrigger,
1688        cx: &mut Context<Self>,
1689    ) -> Task<Result<Option<EditPredictionResult>>> {
1690        self.request_prediction_internal(
1691            project.clone(),
1692            active_buffer.clone(),
1693            position,
1694            trigger,
1695            cx.has_flag::<Zeta2FeatureFlag>(),
1696            cx,
1697        )
1698    }
1699
1700    fn request_prediction_internal(
1701        &mut self,
1702        project: Entity<Project>,
1703        active_buffer: Entity<Buffer>,
1704        position: language::Anchor,
1705        trigger: PredictEditsRequestTrigger,
1706        allow_jump: bool,
1707        cx: &mut Context<Self>,
1708    ) -> Task<Result<Option<EditPredictionResult>>> {
1709        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1710
1711        self.get_or_init_project(&project, cx);
1712        let project_state = self.projects.get(&project.entity_id()).unwrap();
1713        let stored_events = project_state.events_split_by_pause(cx);
1714        let has_events = !stored_events.is_empty();
1715        let events: Vec<Arc<zeta_prompt::Event>> =
1716            stored_events.into_iter().map(|e| e.event).collect();
1717        let debug_tx = project_state.debug_tx.clone();
1718
1719        let snapshot = active_buffer.read(cx).snapshot();
1720        let cursor_point = position.to_point(&snapshot);
1721        let current_offset = position.to_offset(&snapshot);
1722
1723        let mut user_actions: Vec<UserActionRecord> =
1724            project_state.user_actions.iter().cloned().collect();
1725
1726        if let Some(last_action) = user_actions.last() {
1727            if last_action.buffer_id == active_buffer.entity_id()
1728                && current_offset != last_action.offset
1729            {
1730                let timestamp_epoch_ms = SystemTime::now()
1731                    .duration_since(UNIX_EPOCH)
1732                    .map(|d| d.as_millis() as u64)
1733                    .unwrap_or(0);
1734                user_actions.push(UserActionRecord {
1735                    action_type: UserActionType::CursorMovement,
1736                    buffer_id: active_buffer.entity_id(),
1737                    line_number: cursor_point.row,
1738                    offset: current_offset,
1739                    timestamp_epoch_ms,
1740                });
1741            }
1742        }
1743        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1744        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1745        let diagnostic_search_range =
1746            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1747
1748        let related_files = self.context_for_project(&project, cx);
1749
1750        let inputs = EditPredictionModelInput {
1751            project: project.clone(),
1752            buffer: active_buffer.clone(),
1753            snapshot: snapshot.clone(),
1754            position,
1755            events,
1756            related_files,
1757            recent_paths: project_state.recent_paths.clone(),
1758            trigger,
1759            diagnostic_search_range: diagnostic_search_range.clone(),
1760            debug_tx,
1761            user_actions,
1762        };
1763
1764        let can_collect_example = snapshot
1765            .file()
1766            .is_some_and(|file| self.can_collect_file(&project, file, cx))
1767            && self.can_collect_events(&inputs.events, cx)
1768            && self.can_collect_related_files(&project, cx);
1769
1770        if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
1771            let events_for_capture =
1772                self.edit_history_for_project_with_pause_split_last_event(&project, cx);
1773            let related_files_for_capture = inputs.related_files.clone();
1774            if let Some(example_task) = capture_example::capture_example(
1775                project.clone(),
1776                active_buffer.clone(),
1777                position,
1778                events_for_capture,
1779                related_files_for_capture,
1780                false,
1781                cx,
1782            ) {
1783                cx.spawn(async move |_this, _cx| {
1784                    let example = example_task.await?;
1785                    telemetry::event!("Edit Prediction Example Captured", example = example);
1786                    anyhow::Ok(())
1787                })
1788                .detach_and_log_err(cx);
1789            }
1790        }
1791        let task = match self.edit_prediction_model {
1792            EditPredictionModel::Zeta1 => {
1793                if should_send_testing_zeta2_request() {
1794                    let mut zeta2_inputs = inputs.clone();
1795                    zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing;
1796                    zeta2::request_prediction_with_zeta2(
1797                        self,
1798                        zeta2_inputs,
1799                        Default::default(),
1800                        cx,
1801                    )
1802                    .detach();
1803                }
1804                zeta1::request_prediction_with_zeta1(self, inputs, cx)
1805            }
1806            EditPredictionModel::Zeta2 { version } => {
1807                zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
1808            }
1809            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1810            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1811            EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx),
1812        };
1813
1814        cx.spawn(async move |this, cx| {
1815            let prediction = task.await?;
1816
1817            if prediction.is_none() && allow_jump {
1818                let cursor_point = position.to_point(&snapshot);
1819                if has_events
1820                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1821                        active_buffer.clone(),
1822                        &snapshot,
1823                        diagnostic_search_range,
1824                        cursor_point,
1825                        &project,
1826                        cx,
1827                    )
1828                    .await?
1829                {
1830                    return this
1831                        .update(cx, |this, cx| {
1832                            this.request_prediction_internal(
1833                                project,
1834                                jump_buffer,
1835                                jump_position,
1836                                trigger,
1837                                false,
1838                                cx,
1839                            )
1840                        })?
1841                        .await;
1842                }
1843
1844                return anyhow::Ok(None);
1845            }
1846
1847            Ok(prediction)
1848        })
1849    }
1850
1851    async fn next_diagnostic_location(
1852        active_buffer: Entity<Buffer>,
1853        active_buffer_snapshot: &BufferSnapshot,
1854        active_buffer_diagnostic_search_range: Range<Point>,
1855        active_buffer_cursor_point: Point,
1856        project: &Entity<Project>,
1857        cx: &mut AsyncApp,
1858    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1859        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1860        let mut jump_location = active_buffer_snapshot
1861            .diagnostic_groups(None)
1862            .into_iter()
1863            .filter_map(|(_, group)| {
1864                let range = &group.entries[group.primary_ix]
1865                    .range
1866                    .to_point(&active_buffer_snapshot);
1867                if range.overlaps(&active_buffer_diagnostic_search_range) {
1868                    None
1869                } else {
1870                    Some(range.start)
1871                }
1872            })
1873            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1874            .map(|position| {
1875                (
1876                    active_buffer.clone(),
1877                    active_buffer_snapshot.anchor_before(position),
1878                )
1879            });
1880
1881        if jump_location.is_none() {
1882            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1883                let file = buffer.file()?;
1884
1885                Some(ProjectPath {
1886                    worktree_id: file.worktree_id(cx),
1887                    path: file.path().clone(),
1888                })
1889            });
1890
1891            let buffer_task = project.update(cx, |project, cx| {
1892                let (path, _, _) = project
1893                    .diagnostic_summaries(false, cx)
1894                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1895                    .max_by_key(|(path, _, _)| {
1896                        // find the buffer with errors that shares most parent directories
1897                        path.path
1898                            .components()
1899                            .zip(
1900                                active_buffer_path
1901                                    .as_ref()
1902                                    .map(|p| p.path.components())
1903                                    .unwrap_or_default(),
1904                            )
1905                            .take_while(|(a, b)| a == b)
1906                            .count()
1907                    })?;
1908
1909                Some(project.open_buffer(path, cx))
1910            });
1911
1912            if let Some(buffer_task) = buffer_task {
1913                let closest_buffer = buffer_task.await?;
1914
1915                jump_location = closest_buffer
1916                    .read_with(cx, |buffer, _cx| {
1917                        buffer
1918                            .buffer_diagnostics(None)
1919                            .into_iter()
1920                            .min_by_key(|entry| entry.diagnostic.severity)
1921                            .map(|entry| entry.range.start)
1922                    })
1923                    .map(|position| (closest_buffer, position));
1924            }
1925        }
1926
1927        anyhow::Ok(jump_location)
1928    }
1929
1930    async fn send_raw_llm_request(
1931        request: RawCompletionRequest,
1932        client: Arc<Client>,
1933        custom_url: Option<Arc<Url>>,
1934        llm_token: LlmApiToken,
1935        app_version: Version,
1936    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1937        let url = if let Some(custom_url) = custom_url {
1938            custom_url.as_ref().clone()
1939        } else {
1940            client
1941                .http_client()
1942                .build_zed_llm_url("/predict_edits/raw", &[])?
1943        };
1944
1945        Self::send_api_request(
1946            |builder| {
1947                let req = builder
1948                    .uri(url.as_ref())
1949                    .body(serde_json::to_string(&request)?.into());
1950                Ok(req?)
1951            },
1952            client,
1953            llm_token,
1954            app_version,
1955            true,
1956        )
1957        .await
1958    }
1959
1960    pub(crate) async fn send_v3_request(
1961        input: ZetaPromptInput,
1962        prompt_version: ZetaVersion,
1963        client: Arc<Client>,
1964        llm_token: LlmApiToken,
1965        app_version: Version,
1966        trigger: PredictEditsRequestTrigger,
1967    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
1968        let url = client
1969            .http_client()
1970            .build_zed_llm_url("/predict_edits/v3", &[])?;
1971
1972        let request = PredictEditsV3Request {
1973            input,
1974            model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1975            prompt_version,
1976            trigger,
1977        };
1978
1979        Self::send_api_request(
1980            |builder| {
1981                let req = builder
1982                    .uri(url.as_ref())
1983                    .body(serde_json::to_string(&request)?.into());
1984                Ok(req?)
1985            },
1986            client,
1987            llm_token,
1988            app_version,
1989            true,
1990        )
1991        .await
1992    }
1993
1994    fn handle_api_response<T>(
1995        this: &WeakEntity<Self>,
1996        response: Result<(T, Option<EditPredictionUsage>)>,
1997        cx: &mut gpui::AsyncApp,
1998    ) -> Result<T> {
1999        match response {
2000            Ok((data, usage)) => {
2001                if let Some(usage) = usage {
2002                    this.update(cx, |this, cx| {
2003                        this.user_store.update(cx, |user_store, cx| {
2004                            user_store.update_edit_prediction_usage(usage, cx);
2005                        });
2006                    })
2007                    .ok();
2008                }
2009                Ok(data)
2010            }
2011            Err(err) => {
2012                if err.is::<ZedUpdateRequiredError>() {
2013                    cx.update(|cx| {
2014                        this.update(cx, |this, _cx| {
2015                            this.update_required = true;
2016                        })
2017                        .ok();
2018
2019                        let error_message: SharedString = err.to_string().into();
2020                        show_app_notification(
2021                            NotificationId::unique::<ZedUpdateRequiredError>(),
2022                            cx,
2023                            move |cx| {
2024                                cx.new(|cx| {
2025                                    ErrorMessagePrompt::new(error_message.clone(), cx)
2026                                        .with_link_button("Update Zed", "https://zed.dev/releases")
2027                                })
2028                            },
2029                        );
2030                    });
2031                }
2032                Err(err)
2033            }
2034        }
2035    }
2036
2037    async fn send_api_request<Res>(
2038        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2039        client: Arc<Client>,
2040        llm_token: LlmApiToken,
2041        app_version: Version,
2042        require_auth: bool,
2043    ) -> Result<(Res, Option<EditPredictionUsage>)>
2044    where
2045        Res: DeserializeOwned,
2046    {
2047        let http_client = client.http_client();
2048
2049        let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
2050            Some(custom_token)
2051        } else if require_auth {
2052            Some(llm_token.acquire(&client).await?)
2053        } else {
2054            llm_token.acquire(&client).await.ok()
2055        };
2056        let mut did_retry = false;
2057
2058        loop {
2059            let request_builder = http_client::Request::builder().method(Method::POST);
2060
2061            let mut request_builder = request_builder
2062                .header("Content-Type", "application/json")
2063                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2064
2065            // Only add Authorization header if we have a token
2066            if let Some(ref token_value) = token {
2067                request_builder =
2068                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2069            }
2070
2071            let request = build(request_builder)?;
2072
2073            let mut response = http_client.send(request).await?;
2074
2075            if let Some(minimum_required_version) = response
2076                .headers()
2077                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2078                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2079            {
2080                anyhow::ensure!(
2081                    app_version >= minimum_required_version,
2082                    ZedUpdateRequiredError {
2083                        minimum_version: minimum_required_version
2084                    }
2085                );
2086            }
2087
2088            if response.status().is_success() {
2089                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2090
2091                let mut body = Vec::new();
2092                response.body_mut().read_to_end(&mut body).await?;
2093                return Ok((serde_json::from_slice(&body)?, usage));
2094            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2095                did_retry = true;
2096                token = Some(llm_token.refresh(&client).await?);
2097            } else {
2098                let mut body = String::new();
2099                response.body_mut().read_to_string(&mut body).await?;
2100                anyhow::bail!(
2101                    "Request failed with status: {:?}\nBody: {}",
2102                    response.status(),
2103                    body
2104                );
2105            }
2106        }
2107    }
2108
2109    pub fn refresh_context(
2110        &mut self,
2111        project: &Entity<Project>,
2112        buffer: &Entity<language::Buffer>,
2113        cursor_position: language::Anchor,
2114        cx: &mut Context<Self>,
2115    ) {
2116        self.get_or_init_project(project, cx)
2117            .context
2118            .update(cx, |store, cx| {
2119                store.refresh(buffer.clone(), cursor_position, cx);
2120            });
2121    }
2122
2123    #[cfg(feature = "cli-support")]
2124    pub fn set_context_for_buffer(
2125        &mut self,
2126        project: &Entity<Project>,
2127        related_files: Vec<RelatedFile>,
2128        cx: &mut Context<Self>,
2129    ) {
2130        self.get_or_init_project(project, cx)
2131            .context
2132            .update(cx, |store, cx| {
2133                store.set_related_files(related_files, cx);
2134            });
2135    }
2136
2137    #[cfg(feature = "cli-support")]
2138    pub fn set_recent_paths_for_project(
2139        &mut self,
2140        project: &Entity<Project>,
2141        paths: impl IntoIterator<Item = project::ProjectPath>,
2142        cx: &mut Context<Self>,
2143    ) {
2144        let project_state = self.get_or_init_project(project, cx);
2145        project_state.recent_paths = paths.into_iter().collect();
2146    }
2147
2148    fn is_file_open_source(
2149        &self,
2150        project: &Entity<Project>,
2151        file: &Arc<dyn File>,
2152        cx: &App,
2153    ) -> bool {
2154        if !file.is_local() || file.is_private() {
2155            return false;
2156        }
2157        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2158            return false;
2159        };
2160        project_state
2161            .license_detection_watchers
2162            .get(&file.worktree_id(cx))
2163            .as_ref()
2164            .is_some_and(|watcher| watcher.is_project_open_source())
2165    }
2166
2167    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2168        self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2169    }
2170
2171    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2172        if !self.data_collection_choice.is_enabled(cx) {
2173            return false;
2174        }
2175        events.iter().all(|event| {
2176            matches!(
2177                event.as_ref(),
2178                zeta_prompt::Event::BufferChange {
2179                    in_open_source_repo: true,
2180                    ..
2181                }
2182            )
2183        })
2184    }
2185
2186    fn can_collect_related_files(&self, project: &Entity<Project>, cx: &mut App) -> bool {
2187        if !self.data_collection_choice.is_enabled(cx) {
2188            return false;
2189        }
2190
2191        let related_with_buffers = self.context_for_project_with_buffers(project, cx);
2192
2193        related_with_buffers.iter().all(|(_, buffer)| {
2194            buffer
2195                .read(cx)
2196                .file()
2197                .is_some_and(|file| self.is_file_open_source(project, &file, cx))
2198        })
2199    }
2200
2201    fn load_data_collection_choice() -> DataCollectionChoice {
2202        let choice = KEY_VALUE_STORE
2203            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2204            .log_err()
2205            .flatten();
2206
2207        match choice.as_deref() {
2208            Some("true") => DataCollectionChoice::Enabled,
2209            Some("false") => DataCollectionChoice::Disabled,
2210            Some(_) => {
2211                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2212                DataCollectionChoice::NotAnswered
2213            }
2214            None => DataCollectionChoice::NotAnswered,
2215        }
2216    }
2217
2218    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2219        self.data_collection_choice = self.data_collection_choice.toggle();
2220        let new_choice = self.data_collection_choice;
2221        let is_enabled = new_choice.is_enabled(cx);
2222        db::write_and_log(cx, move || {
2223            KEY_VALUE_STORE.write_kvp(
2224                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2225                is_enabled.to_string(),
2226            )
2227        });
2228    }
2229
2230    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2231        self.shown_predictions.iter()
2232    }
2233
2234    pub fn shown_completions_len(&self) -> usize {
2235        self.shown_predictions.len()
2236    }
2237
2238    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2239        self.rated_predictions.contains(id)
2240    }
2241
2242    pub fn rate_prediction(
2243        &mut self,
2244        prediction: &EditPrediction,
2245        rating: EditPredictionRating,
2246        feedback: String,
2247        cx: &mut Context<Self>,
2248    ) {
2249        self.rated_predictions.insert(prediction.id.clone());
2250        telemetry::event!(
2251            "Edit Prediction Rated",
2252            rating,
2253            inputs = prediction.inputs,
2254            output = prediction
2255                .edit_preview
2256                .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2257            feedback
2258        );
2259        self.client.telemetry().flush_events().detach();
2260        cx.notify();
2261    }
2262}
2263
2264pub(crate) fn filter_redundant_excerpts(
2265    mut related_files: Vec<RelatedFile>,
2266    cursor_path: &Path,
2267    cursor_row_range: Range<u32>,
2268) -> Vec<RelatedFile> {
2269    for file in &mut related_files {
2270        if file.path.as_ref() == cursor_path {
2271            file.excerpts.retain(|excerpt| {
2272                excerpt.row_range.start < cursor_row_range.start
2273                    || excerpt.row_range.end > cursor_row_range.end
2274            });
2275        }
2276    }
2277    related_files.retain(|file| !file.excerpts.is_empty());
2278    related_files
2279}
2280
/// Error returned when the server's minimum-required-version response header names a
/// version newer than the client's app version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum Zed version reported by the server.
    minimum_version: Version,
}
2288
/// The user's edit-prediction data-collection preference.
///
/// Persisted in the key-value store as `"true"` / `"false"`. A missing or
/// unrecognized stored value loads as `NotAnswered`.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No preference has been recorded.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2295
2296impl DataCollectionChoice {
2297    pub fn is_enabled(self, cx: &App) -> bool {
2298        if cx.is_staff() {
2299            return true;
2300        }
2301        match self {
2302            Self::Enabled => true,
2303            Self::NotAnswered | Self::Disabled => false,
2304        }
2305    }
2306
2307    #[must_use]
2308    pub fn toggle(&self) -> DataCollectionChoice {
2309        match self {
2310            Self::Enabled => Self::Disabled,
2311            Self::Disabled => Self::Enabled,
2312            Self::NotAnswered => Self::Enabled,
2313        }
2314    }
2315}
2316
2317impl From<bool> for DataCollectionChoice {
2318    fn from(value: bool) -> Self {
2319        match value {
2320            true => DataCollectionChoice::Enabled,
2321            false => DataCollectionChoice::Disabled,
2322        }
2323    }
2324}
2325
2326struct ZedPredictUpsell;
2327
2328impl Dismissable for ZedPredictUpsell {
2329    const KEY: &'static str = "dismissed-edit-predict-upsell";
2330
2331    fn dismissed() -> bool {
2332        // To make this backwards compatible with older versions of Zed, we
2333        // check if the user has seen the previous Edit Prediction Onboarding
2334        // before, by checking the data collection choice which was written to
2335        // the database once the user clicked on "Accept and Enable"
2336        if KEY_VALUE_STORE
2337            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2338            .log_err()
2339            .is_some_and(|s| s.is_some())
2340        {
2341            return true;
2342        }
2343
2344        KEY_VALUE_STORE
2345            .read_kvp(Self::KEY)
2346            .log_err()
2347            .is_some_and(|s| s.is_some())
2348    }
2349}
2350
2351pub fn should_show_upsell_modal() -> bool {
2352    !ZedPredictUpsell::dismissed()
2353}
2354
2355pub fn init(cx: &mut App) {
2356    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2357        workspace.register_action(
2358            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2359                ZedPredictModal::toggle(
2360                    workspace,
2361                    workspace.user_store().clone(),
2362                    workspace.client().clone(),
2363                    window,
2364                    cx,
2365                )
2366            },
2367        );
2368
2369        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2370            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2371                settings
2372                    .project
2373                    .all_languages
2374                    .edit_predictions
2375                    .get_or_insert_default()
2376                    .provider = Some(EditPredictionProvider::None)
2377            });
2378        });
2379        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2380            EditPredictionStore::try_global(cx).and_then(|store| {
2381                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2382            })
2383        }
2384
2385        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2386            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2387                copilot_ui::initiate_sign_in(copilot, window, cx);
2388            }
2389        });
2390        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2391            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2392                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2393            }
2394        });
2395        workspace.register_action(|workspace, _: &SignOut, window, cx| {
2396            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2397                copilot_ui::initiate_sign_out(copilot, window, cx);
2398            }
2399        });
2400    })
2401    .detach();
2402}