edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{
   5    self, PromptFormat, RawCompletionRequest, RawCompletionResponse,
   6};
   7use cloud_llm_client::{
   8    EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection,
   9    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  10    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  11};
  12use collections::{HashMap, HashSet};
  13use copilot::Copilot;
  14use db::kvp::{Dismissable, KEY_VALUE_STORE};
  15use edit_prediction_context::EditPredictionExcerptOptions;
  16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::{
  19    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  20    channel::mpsc::{self, UnboundedReceiver},
  21    select_biased,
  22};
  23use gpui::BackgroundExecutor;
  24use gpui::http_client::Url;
  25use gpui::{
  26    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  27    http_client::{self, AsyncBody, Method},
  28    prelude::*,
  29};
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, RefreshLlmTokenListener};
  34use project::{Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
  39use std::collections::{VecDeque, hash_map};
  40use text::Edit;
  41use workspace::Workspace;
  42use zeta_prompt::ZetaVersion;
  43
  44use std::ops::Range;
  45use std::path::Path;
  46use std::rc::Rc;
  47use std::str::FromStr as _;
  48use std::sync::{Arc, LazyLock};
  49use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  50use std::{env, mem};
  51use thiserror::Error;
  52use util::{RangeExt as _, ResultExt as _};
  53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  54
  55pub mod cursor_excerpt;
  56pub mod example_spec;
  57mod license_detection;
  58pub mod mercury;
  59mod onboarding_modal;
  60pub mod open_ai_response;
  61mod prediction;
  62pub mod sweep_ai;
  63
  64pub mod udiff;
  65
  66mod capture_example;
  67mod zed_edit_prediction_delegate;
  68pub mod zeta1;
  69pub mod zeta2;
  70
  71#[cfg(test)]
  72mod edit_prediction_tests;
  73
  74use crate::capture_example::should_sample_edit_prediction_example_capture;
  75use crate::license_detection::LicenseDetectionWatcher;
  76use crate::mercury::Mercury;
  77use crate::onboarding_modal::ZedPredictModal;
  78pub use crate::prediction::EditPrediction;
  79pub use crate::prediction::EditPredictionId;
  80use crate::prediction::EditPredictionResult;
  81pub use crate::sweep_ai::SweepAi;
  82pub use capture_example::capture_example;
  83pub use language_model::ApiKeyState;
  84pub use telemetry_events::EditPredictionRating;
  85pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  86
// Actions registered by this crate under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  96
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
// NOTE(review): used outside this chunk; by its name, edits within this many lines of the
// previous change are grouped into the same event — confirm at the use site.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): presumably the time window within which a new edit is merged into the last
// change — confirm at the use site.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// Storage key for the user's data-collection choice; presumably read/written through the
// key-value store imported at the top of the file — verify at the use site.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the batching interval for reporting rejected predictions in
// `handle_rejected_predictions` — confirm at the use site.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 103
 104pub struct SweepFeatureFlag;
 105
 106impl FeatureFlag for SweepFeatureFlag {
 107    const NAME: &str = "sweep-ai";
 108}
 109
 110pub struct MercuryFeatureFlag;
 111
 112impl FeatureFlag for MercuryFeatureFlag {
 113    const NAME: &str = "mercury";
 114}
 115
/// Options a newly created [`EditPredictionStore`] starts with (`new` assigns them to
/// `options`): excerpts between 128 and 512 bytes using the default prompt format.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        // NOTE(review): presumably aims to place half of the excerpt before the cursor —
        // confirm in `edit_prediction_context`.
        target_before_cursor_over_total_bytes: 0.5,
    },
    prompt_format: PromptFormat::DEFAULT,
};
 124
 125static USE_OLLAMA: LazyLock<bool> =
 126    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 127
 128static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 129    match env::var("ZED_ZETA2_MODEL").as_deref() {
 130        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 131        Ok(model) => model,
 132        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 133        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 134    }
 135    .to_string()
 136});
 137
/// Feature flag named `zeta2`. Overrides `enabled_for_staff` to return `true`.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}
 147
/// Feature flag named `edit-prediction-example-capture`; presumably gates capturing edit
/// prediction examples (see the `capture_example` module). Overrides `enabled_for_staff` to
/// return `true`.
pub struct EditPredictionExampleCaptureFeatureFlag;

impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
    const NAME: &'static str = "edit-prediction-example-capture";

    fn enabled_for_staff() -> bool {
        true
    }
}
 157
/// App-global slot holding the shared [`EditPredictionStore`] entity. It is installed by
/// `EditPredictionStore::global` and read by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 162
/// Central edit-prediction state: per-project bookkeeping, the selected model, provider
/// clients, and the channel used to report rejected predictions.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token for LLM requests; `new` subscribes to `RefreshLlmTokenListener` and refreshes it on
    /// every event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    use_context: bool,
    options: ZetaOptions,
    // Initialized to `false` in `new`. NOTE(review): presumably set when the server reports that
    // a newer client version is required — confirm.
    update_required: bool,
    #[cfg(feature = "cli-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    /// Selected model; `new` starts with `Zeta2`, not the derived `Default` (`Zeta1`).
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Sender half of the queue drained by `handle_rejected_predictions` (spawned in `new`).
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    /// Endpoint override: `ZED_PREDICT_EDITS_URL`, or a local Ollama URL when
    /// `ZED_ZETA2_OLLAMA` is set.
    custom_predict_edits_url: Option<Arc<Url>>,
}
 183
/// Model backend selected for an [`EditPredictionStore`].
///
/// `Zeta1` is the derived `Default`, but `EditPredictionStore::new` initializes the store with
/// `Zeta2 { version: Default::default() }`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    /// Zeta2, parameterized by a `zeta_prompt::ZetaVersion`.
    Zeta2 {
        version: ZetaVersion,
    },
    Sweep,
    Mercury,
}
 194
/// Inputs collected for one edit prediction request (constructed outside this chunk).
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    // NOTE(review): presumably a snapshot of `buffer` taken when the request was made — confirm.
    snapshot: BufferSnapshot,
    // NOTE(review): presumably the cursor position in `buffer` — confirm at the construction site.
    position: Anchor,
    /// Edit-history events (the same `zeta_prompt::Event` type carried by `StoredEvent`).
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    // NOTE(review): presumably the rows searched for diagnostics near the cursor — confirm.
    diagnostic_search_range: Range<Point>,
    /// Optional sink for [`DebugEvent`]s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
 208
/// Tunable request options: how context excerpts are sized and which prompt format is used.
/// See `DEFAULT_OPTIONS` for the starting values.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: EditPredictionExcerptOptions,
    pub prompt_format: predict_edits_v3::PromptFormat,
}
 214
/// Diagnostic notifications sent over a project's optional `debug_tx` channel.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
 222
/// Sent when a project's related-excerpt store starts refreshing. The send site in
/// `handle_excerpt_store_event` leaves `search_prompt` empty.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}
 229
/// Sent when a project's related-excerpt store finishes refreshing.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // NOTE(review): label/value pairs; presumably the cache hit/miss counts and latencies from
    // `RelatedExcerptStoreEvent::FinishedRefresh` — confirm at the send site.
    pub metadata: Vec<(&'static str, SharedString)>,
}
 236
/// Presumably sent when a prediction request starts for `buffer` at `position`; `prompt`
/// carries the prompt text when one is available (send site is outside this chunk).
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}
 243
/// Presumably sent when a prediction request for `buffer` at `position` completes;
/// `model_output` carries the raw model output when available (send site is outside this chunk).
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
 250
/// Alias for the debug-info type of the `predict_edits_v3` API.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 252
/// Maximum number of [`UserActionRecord`]s kept per project; `ProjectState::record_user_action`
/// evicts the oldest record once this many are stored.
const USER_ACTION_HISTORY_SIZE: usize = 16;
 254
/// One entry in a project's user-action history (see `ProjectState::record_user_action`).
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer, presumably the one the action happened in.
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    /// Time of the action in milliseconds; per the name, since the Unix epoch.
    pub timestamp_epoch_ms: u64,
}
 263
/// Kind of action captured in a [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 272
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// Prompt-level description of the change (`LastEvent::finalize` builds a `BufferChange`).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change (`LastEvent::finalize` stores its `old_snapshot`).
    pub old_snapshot: TextBufferSnapshot,
}
 279
/// Per-project bookkeeping owned by [`EditPredictionStore`], created in `get_or_init_project`.
struct ProjectState {
    /// Finalized edit-history events; `events()` appends the in-progress `last_event` after them.
    events: VecDeque<StoredEvent>,
    /// In-progress change that has not been moved into `events` yet.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Registered buffers keyed by buffer entity id (presumably filled by `register_buffer_impl`).
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests; the `ArrayVec` capacity limits this to two.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Optional sink for [`DebugEvent`]s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt store whose events are forwarded to `handle_excerpt_store_event`.
    context: Entity<RelatedExcerptStore>,
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Most recent user actions, capped at `USER_ACTION_HISTORY_SIZE`.
    user_actions: VecDeque<UserActionRecord>,
    _subscription: gpui::Subscription,
    /// Copilot instance started by `start_copilot_for_project`, if any.
    copilot: Option<Entity<Copilot>>,
}
 297
 298impl ProjectState {
 299    fn record_user_action(&mut self, action: UserActionRecord) {
 300        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 301            self.user_actions.pop_front();
 302        }
 303        self.user_actions.push_back(action);
 304    }
 305
 306    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 307        self.events
 308            .iter()
 309            .cloned()
 310            .chain(
 311                self.last_event
 312                    .as_ref()
 313                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 314            )
 315            .collect()
 316    }
 317
 318    pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
 319        self.events
 320            .iter()
 321            .cloned()
 322            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 323                let (one, two) = event.split_by_pause();
 324                let one = one.finalize(&self.license_detection_watchers, cx);
 325                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 326                one.into_iter().chain(two)
 327            }))
 328            .collect()
 329    }
 330
 331    fn cancel_pending_prediction(
 332        &mut self,
 333        pending_prediction: PendingPrediction,
 334        cx: &mut Context<EditPredictionStore>,
 335    ) {
 336        self.cancelled_predictions.insert(pending_prediction.id);
 337
 338        cx.spawn(async move |this, cx| {
 339            let Some(prediction_id) = pending_prediction.task.await else {
 340                return;
 341            };
 342
 343            this.update(cx, |this, _cx| {
 344                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 345            })
 346            .ok();
 347        })
 348        .detach()
 349    }
 350
 351    fn active_buffer(
 352        &self,
 353        project: &Entity<Project>,
 354        cx: &App,
 355    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 356        let project = project.read(cx);
 357        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 358        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 359        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 360        Some((active_buffer, registered_buffer.last_position))
 361    }
 362}
 363
/// The prediction currently held for a project (`ProjectState::current_prediction`), together
/// with what requested it and whether/how it was displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
 371
 372impl CurrentEditPrediction {
 373    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 374        let Some(new_edits) = self
 375            .prediction
 376            .interpolate(&self.prediction.buffer.read(cx))
 377        else {
 378            return false;
 379        };
 380
 381        if self.prediction.buffer != old_prediction.prediction.buffer {
 382            return true;
 383        }
 384
 385        let Some(old_edits) = old_prediction
 386            .prediction
 387            .interpolate(&old_prediction.prediction.buffer.read(cx))
 388        else {
 389            return true;
 390        };
 391
 392        let requested_by_buffer_id = self.requested_by.buffer_id();
 393
 394        // This reduces the occurrence of UI thrash from replacing edits
 395        //
 396        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 397        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 398            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 399            && old_edits.len() == 1
 400            && new_edits.len() == 1
 401        {
 402            let (old_range, old_text) = &old_edits[0];
 403            let (new_range, new_text) = &new_edits[0];
 404            new_range == old_range && new_text.starts_with(old_text.as_ref())
 405        } else {
 406            true
 407        }
 408    }
 409}
 410
/// What triggered a prediction request: a diagnostics update, or a specific buffer (by entity id).
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    Buffer(EntityId),
}
 416
 417impl PredictionRequestedBy {
 418    pub fn buffer_id(&self) -> Option<EntityId> {
 419        match self {
 420            PredictionRequestedBy::DiagnosticsUpdate => None,
 421            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 422        }
 423    }
 424}
 425
/// An in-flight prediction request. `task` resolves to the resulting prediction's id, or `None`;
/// on `None`, `ProjectState::cancel_pending_prediction` has nothing to report as rejected.
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}
 431
/// A prediction from the perspective of a buffer.
///
/// NOTE(review): presumably `Local` means the prediction edits this buffer, while `Jump` means
/// it targets another location the editor can jump to — confirm at the construction site.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
 438
 439#[cfg(test)]
 440impl std::ops::Deref for BufferEditPrediction<'_> {
 441    type Target = EditPrediction;
 442
 443    fn deref(&self) -> &Self::Target {
 444        match self {
 445            BufferEditPrediction::Local { prediction } => prediction,
 446            BufferEditPrediction::Jump { prediction } => prediction,
 447        }
 448    }
 449}
 450
/// Per-buffer state kept for buffers registered with a project.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    /// Last recorded position in the buffer (presumably the cursor); returned alongside the buffer
    /// by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
 457
/// The latest buffer change that has not been finalized into a [`StoredEvent`] yet, tracked as
/// a pair of snapshots (and files) from before and after the change.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    /// Snapshot at the most recent editing pause; `split_by_pause` splits the change here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 468
 469impl LastEvent {
 470    pub fn finalize(
 471        &self,
 472        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 473        cx: &App,
 474    ) -> Option<StoredEvent> {
 475        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 476        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 477
 478        let in_open_source_repo =
 479            [self.new_file.as_ref(), self.old_file.as_ref()]
 480                .iter()
 481                .all(|file| {
 482                    file.is_some_and(|file| {
 483                        license_detection_watchers
 484                            .get(&file.worktree_id(cx))
 485                            .is_some_and(|watcher| watcher.is_project_open_source())
 486                    })
 487                });
 488
 489        let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 490
 491        if path == old_path && diff.is_empty() {
 492            None
 493        } else {
 494            Some(StoredEvent {
 495                event: Arc::new(zeta_prompt::Event::BufferChange {
 496                    old_path,
 497                    path,
 498                    diff,
 499                    in_open_source_repo,
 500                    // TODO: Actually detect if this edit was predicted or not
 501                    predicted: false,
 502                }),
 503                old_snapshot: self.old_snapshot.clone(),
 504            })
 505        }
 506    }
 507
 508    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 509        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 510            return (self.clone(), None);
 511        };
 512
 513        let before = LastEvent {
 514            old_snapshot: self.old_snapshot.clone(),
 515            new_snapshot: boundary_snapshot.clone(),
 516            old_file: self.old_file.clone(),
 517            new_file: self.new_file.clone(),
 518            edit_range: None,
 519            snapshot_after_last_editing_pause: None,
 520            last_edit_time: self.last_edit_time,
 521        };
 522
 523        let after = LastEvent {
 524            old_snapshot: boundary_snapshot.clone(),
 525            new_snapshot: self.new_snapshot.clone(),
 526            old_file: self.old_file.clone(),
 527            new_file: self.new_file.clone(),
 528            edit_range: None,
 529            snapshot_after_last_editing_pause: None,
 530            last_edit_time: self.last_edit_time,
 531        };
 532
 533        (before, Some(after))
 534    }
 535}
 536
 537pub(crate) fn compute_diff_between_snapshots(
 538    old_snapshot: &TextBufferSnapshot,
 539    new_snapshot: &TextBufferSnapshot,
 540) -> Option<String> {
 541    let edits: Vec<Edit<usize>> = new_snapshot
 542        .edits_since::<usize>(&old_snapshot.version)
 543        .collect();
 544
 545    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 546
 547    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 548    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 549    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 550    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 551
 552    const CONTEXT_LINES: u32 = 3;
 553
 554    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 555    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 556    let old_context_end_row =
 557        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 558    let new_context_end_row =
 559        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 560
 561    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 562    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 563    let old_end_line_offset = old_snapshot
 564        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 565    let new_end_line_offset = new_snapshot
 566        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 567    let old_edit_range = old_start_line_offset..old_end_line_offset;
 568    let new_edit_range = new_start_line_offset..new_end_line_offset;
 569
 570    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 571    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 572
 573    let diff = language::unified_diff_with_offsets(
 574        &old_region_text,
 575        &new_region_text,
 576        old_context_start_row,
 577        new_context_start_row,
 578    );
 579
 580    Some(diff)
 581}
 582
 583fn buffer_path_with_id_fallback(
 584    file: Option<&Arc<dyn File>>,
 585    snapshot: &TextBufferSnapshot,
 586    cx: &App,
 587) -> Arc<Path> {
 588    if let Some(file) = file {
 589        file.full_path(cx).into()
 590    } else {
 591        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 592    }
 593}
 594
 595impl EditPredictionStore {
 596    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 597        cx.try_global::<EditPredictionStoreGlobal>()
 598            .map(|global| global.0.clone())
 599    }
 600
 601    pub fn global(
 602        client: &Arc<Client>,
 603        user_store: &Entity<UserStore>,
 604        cx: &mut App,
 605    ) -> Entity<Self> {
 606        cx.try_global::<EditPredictionStoreGlobal>()
 607            .map(|global| global.0.clone())
 608            .unwrap_or_else(|| {
 609                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 610                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 611                ep_store
 612            })
 613    }
 614
    /// Creates a store bound to `client` and `user_store`.
    ///
    /// Side effects on `cx`:
    /// - spawns a detached background task that drains rejected predictions through
    ///   `handle_rejected_predictions`;
    /// - subscribes to `RefreshLlmTokenListener` so the LLM token is refreshed on each event;
    /// - runs `configure_context_retrieval` now, again once feature flags are ready, and on
    ///   every `SettingsStore` change.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections are queued here and reported from the background task below.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    // Refresh failures are logged, not propagated.
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            #[cfg(feature = "cli-support")]
            eval_cache: None,
            // Note: this differs from `EditPredictionModel::default()`, which is `Zeta1`.
            edit_prediction_model: EditPredictionModel::Zeta2 {
                version: Default::default(),
            },
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),

            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
            // An explicit `ZED_PREDICT_EDITS_URL` wins; an unparsable value is logged and
            // ignored (yielding `None`, without the Ollama fallback). Otherwise default to a
            // local Ollama endpoint when `ZED_ZETA2_OLLAMA` is set.
            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
                Err(_) => {
                    if *USE_OLLAMA {
                        Some(
                            // The literal URL is well-formed, so parsing cannot fail here.
                            Url::parse("http://localhost:11434/v1/chat/completions")
                                .unwrap()
                                .into(),
                        )
                    } else {
                        None
                    }
                }
            },
        };

        this.configure_context_retrieval(cx);
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
 703
 704    #[cfg(test)]
 705    pub fn set_custom_predict_edits_url(&mut self, url: Url) {
 706        self.custom_predict_edits_url = Some(url.into());
 707    }
 708
    /// Replaces the selected [`EditPredictionModel`].
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
 712
    /// Whether the Sweep AI API key state currently holds a key.
    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
        self.sweep_ai.api_token.read(cx).has_key()
    }
 716
    /// Whether the Mercury API key state currently holds a key.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }
 720
    /// Installs an [`EvalCache`]; only compiled with the `cli-support` feature.
    #[cfg(feature = "cli-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 725
    /// Current [`ZetaOptions`] (initially `DEFAULT_OPTIONS`).
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 729
    /// Replaces the current [`ZetaOptions`].
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 733
    /// Sets the `use_context` flag (initialized to `false` in `new`).
    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }
 737
 738    pub fn clear_history(&mut self) {
 739        for project_state in self.projects.values_mut() {
 740            project_state.events.clear();
 741            project_state.last_event.take();
 742        }
 743    }
 744
 745    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 746        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 747            project_state.events.clear();
 748            project_state.last_event.take();
 749        }
 750    }
 751
 752    pub fn edit_history_for_project(
 753        &self,
 754        project: &Entity<Project>,
 755        cx: &App,
 756    ) -> Vec<StoredEvent> {
 757        self.projects
 758            .get(&project.entity_id())
 759            .map(|project_state| project_state.events(cx))
 760            .unwrap_or_default()
 761    }
 762
 763    pub fn edit_history_for_project_with_pause_split_last_event(
 764        &self,
 765        project: &Entity<Project>,
 766        cx: &App,
 767    ) -> Vec<StoredEvent> {
 768        self.projects
 769            .get(&project.entity_id())
 770            .map(|project_state| project_state.events_split_by_pause(cx))
 771            .unwrap_or_default()
 772    }
 773
 774    pub fn context_for_project<'a>(
 775        &'a self,
 776        project: &Entity<Project>,
 777        cx: &'a mut App,
 778    ) -> Vec<RelatedFile> {
 779        self.projects
 780            .get(&project.entity_id())
 781            .map(|project| {
 782                project
 783                    .context
 784                    .update(cx, |context, cx| context.related_files(cx))
 785            })
 786            .unwrap_or_default()
 787    }
 788
 789    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 790        self.projects
 791            .get(&project.entity_id())
 792            .and_then(|project| project.copilot.clone())
 793    }
 794
 795    pub fn start_copilot_for_project(
 796        &mut self,
 797        project: &Entity<Project>,
 798        cx: &mut Context<Self>,
 799    ) -> Option<Entity<Copilot>> {
 800        let state = self.get_or_init_project(project, cx);
 801
 802        if state.copilot.is_some() {
 803            return state.copilot.clone();
 804        }
 805        let _project = project.clone();
 806        let project = project.read(cx);
 807
 808        let node = project.node_runtime().cloned();
 809        if let Some(node) = node {
 810            let next_id = project.languages().next_language_server_id();
 811            let fs = project.fs().clone();
 812
 813            let copilot = cx.new(|cx| Copilot::new(_project, next_id, fs, node, cx));
 814            state.copilot = Some(copilot.clone());
 815            Some(copilot)
 816        } else {
 817            None
 818        }
 819    }
 820
 821    pub fn context_for_project_with_buffers<'a>(
 822        &'a self,
 823        project: &Entity<Project>,
 824        cx: &'a mut App,
 825    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 826        self.projects
 827            .get(&project.entity_id())
 828            .map(|project| {
 829                project
 830                    .context
 831                    .update(cx, |context, cx| context.related_files_with_buffers(cx))
 832            })
 833            .unwrap_or_default()
 834    }
 835
 836    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 837        if matches!(
 838            self.edit_prediction_model,
 839            EditPredictionModel::Zeta2 { .. }
 840        ) {
 841            self.user_store.read(cx).edit_prediction_usage()
 842        } else {
 843            None
 844        }
 845    }
 846
    /// Ensures `project` has a `ProjectState` in this store (see `get_or_init_project`).
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
 850
    /// Registers `buffer` under `project`, creating the project's state first if needed. The
    /// per-buffer work happens in `register_buffer_impl` (defined outside this chunk).
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
 860
 861    fn get_or_init_project(
 862        &mut self,
 863        project: &Entity<Project>,
 864        cx: &mut Context<Self>,
 865    ) -> &mut ProjectState {
 866        let entity_id = project.entity_id();
 867        self.projects
 868            .entry(entity_id)
 869            .or_insert_with(|| ProjectState {
 870                context: {
 871                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 872                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 873                        this.handle_excerpt_store_event(entity_id, event);
 874                    })
 875                    .detach();
 876                    related_excerpt_store
 877                },
 878                events: VecDeque::new(),
 879                last_event: None,
 880                recent_paths: VecDeque::new(),
 881                debug_tx: None,
 882                registered_buffers: HashMap::default(),
 883                current_prediction: None,
 884                cancelled_predictions: HashSet::default(),
 885                pending_predictions: ArrayVec::new(),
 886                next_pending_prediction_id: 0,
 887                last_prediction_refresh: None,
 888                license_detection_watchers: HashMap::default(),
 889                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
 890                _subscription: cx.subscribe(&project, Self::handle_project_event),
 891                copilot: None,
 892            })
 893    }
 894
 895    pub fn remove_project(&mut self, project: &Entity<Project>) {
 896        self.projects.remove(&project.entity_id());
 897    }
 898
 899    fn handle_excerpt_store_event(
 900        &mut self,
 901        project_entity_id: EntityId,
 902        event: &RelatedExcerptStoreEvent,
 903    ) {
 904        if let Some(project_state) = self.projects.get(&project_entity_id) {
 905            if let Some(debug_tx) = project_state.debug_tx.clone() {
 906                match event {
 907                    RelatedExcerptStoreEvent::StartedRefresh => {
 908                        debug_tx
 909                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 910                                ContextRetrievalStartedDebugEvent {
 911                                    project_entity_id: project_entity_id,
 912                                    timestamp: Instant::now(),
 913                                    search_prompt: String::new(),
 914                                },
 915                            ))
 916                            .ok();
 917                    }
 918                    RelatedExcerptStoreEvent::FinishedRefresh {
 919                        cache_hit_count,
 920                        cache_miss_count,
 921                        mean_definition_latency,
 922                        max_definition_latency,
 923                    } => {
 924                        debug_tx
 925                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 926                                ContextRetrievalFinishedDebugEvent {
 927                                    project_entity_id: project_entity_id,
 928                                    timestamp: Instant::now(),
 929                                    metadata: vec![
 930                                        (
 931                                            "Cache Hits",
 932                                            format!(
 933                                                "{}/{}",
 934                                                cache_hit_count,
 935                                                cache_hit_count + cache_miss_count
 936                                            )
 937                                            .into(),
 938                                        ),
 939                                        (
 940                                            "Max LSP Time",
 941                                            format!("{} ms", max_definition_latency.as_millis())
 942                                                .into(),
 943                                        ),
 944                                        (
 945                                            "Mean LSP Time",
 946                                            format!("{} ms", mean_definition_latency.as_millis())
 947                                                .into(),
 948                                        ),
 949                                    ],
 950                                },
 951                            ))
 952                            .ok();
 953                    }
 954                }
 955            }
 956        }
 957    }
 958
 959    pub fn debug_info(
 960        &mut self,
 961        project: &Entity<Project>,
 962        cx: &mut Context<Self>,
 963    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 964        let project_state = self.get_or_init_project(project, cx);
 965        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 966        project_state.debug_tx = Some(debug_watch_tx);
 967        debug_watch_rx
 968    }
 969
 970    fn handle_project_event(
 971        &mut self,
 972        project: Entity<Project>,
 973        event: &project::Event,
 974        cx: &mut Context<Self>,
 975    ) {
 976        // TODO [zeta2] init with recent paths
 977        match event {
 978            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 979                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 980                    return;
 981                };
 982                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 983                if let Some(path) = path {
 984                    if let Some(ix) = project_state
 985                        .recent_paths
 986                        .iter()
 987                        .position(|probe| probe == &path)
 988                    {
 989                        project_state.recent_paths.remove(ix);
 990                    }
 991                    project_state.recent_paths.push_front(path);
 992                }
 993            }
 994            project::Event::DiagnosticsUpdated { .. } => {
 995                if cx.has_flag::<Zeta2FeatureFlag>() {
 996                    self.refresh_prediction_from_diagnostics(project, cx);
 997                }
 998            }
 999            _ => (),
1000        }
1001    }
1002
    /// Ensures `buffer` is tracked in `project_state` and returns its [`RegisteredBuffer`].
    ///
    /// Side effects:
    /// - If the buffer has a file whose worktree is found in `project`, a
    ///   [`LicenseDetectionWatcher`] is created for that worktree (at most once per worktree id).
    ///   It is removed again when the worktree entity is released.
    /// - On first registration, a [`RegisteredBuffer`] captures the buffer's current text
    ///   snapshot and file, and subscribes to the buffer. `Edited` events are forwarded to
    ///   `report_changes_for_buffer`, and the entry is removed when the buffer is released.
    ///
    /// An already-registered buffer is returned as-is; its snapshot is not refreshed here.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher when the worktree goes away. The project state is
                        // looked up by id because it may no longer exist at that point.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Forward edits. The project is held weakly, so edits are ignored
                        // once the project entity has been dropped.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Stop tracking the buffer once it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1068
    /// Records the edits made to `buffer` since its last recorded snapshot.
    ///
    /// Steps:
    /// - Registers the buffer if needed, then swaps in its current snapshot and file. Returns
    ///   early when the version is unchanged.
    /// - Summarizes the edits into a single [`UserActionRecord`].
    /// - Updates the project's edit history. If the change directly continues `last_event` in
    ///   the same buffer, and lies within `CHANGE_GROUPING_LINE_SPAN` rows of the previous edit
    ///   range, it is coalesced into `last_event`.
    /// - Otherwise the previous `last_event` is finalized into `events` and a new `LastEvent`
    ///   is started.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the snapshot we last recorded.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        // Aggregate all edits between the two snapshots. Lengths are in `usize` offset units
        // (presumably bytes — confirm). The combined range spans from the first edit's start
        // to the last edit's end.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        if num_edits > 0 {
            // Classify the change. If every edit inserted exactly one unit it counts as a char
            // insertion; if every edit deleted exactly one unit with nothing inserted, a char
            // deletion. Anything larger is a selection-sized action.
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            // The action is located at the end of the last edit in the new snapshot.
            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // This change continues `last_event` only if it starts exactly where the last event
            // ended: same buffer and same version.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // Coalesce when the new and previous edit ranges overlap, or when they are
            // separated by at most `CHANGE_GROUPING_LINE_SPAN` rows.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && edit_range
                    .as_ref()
                    .zip(last_event.edit_range.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        if a.start > b.end {
                            a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else if b.start > a.end {
                            b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else {
                            true
                        }
                    });

            if should_coalesce {
                // If the user paused for at least `LAST_CHANGE_GROUPING_TIME`, remember the
                // snapshot as it was right before this edit.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = edit_range;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalesced: finalize the previous event into the bounded history.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Keeps `events` at most `EVENT_COUNT_MAX - 1` long (presumably leaving room
                // for `last_event` — confirm).
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1192
1193    fn prediction_at(
1194        &mut self,
1195        buffer: &Entity<Buffer>,
1196        position: Option<language::Anchor>,
1197        project: &Entity<Project>,
1198        cx: &App,
1199    ) -> Option<BufferEditPrediction<'_>> {
1200        let project_state = self.projects.get_mut(&project.entity_id())?;
1201        if let Some(position) = position
1202            && let Some(buffer) = project_state
1203                .registered_buffers
1204                .get_mut(&buffer.entity_id())
1205        {
1206            buffer.last_position = Some(position);
1207        }
1208
1209        let CurrentEditPrediction {
1210            requested_by,
1211            prediction,
1212            ..
1213        } = project_state.current_prediction.as_ref()?;
1214
1215        if prediction.targets_buffer(buffer.read(cx)) {
1216            Some(BufferEditPrediction::Local { prediction })
1217        } else {
1218            let show_jump = match requested_by {
1219                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1220                    requested_by_buffer_id == &buffer.entity_id()
1221                }
1222                PredictionRequestedBy::DiagnosticsUpdate => true,
1223            };
1224
1225            if show_jump {
1226                Some(BufferEditPrediction::Jump { prediction })
1227            } else {
1228                None
1229            }
1230        }
1231    }
1232
1233    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1234        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1235            return;
1236        };
1237
1238        let Some(current_prediction) = project_state.current_prediction.take() else {
1239            return;
1240        };
1241
1242        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1243            project_state.cancel_pending_prediction(pending_prediction, cx);
1244        }
1245
1246        match self.edit_prediction_model {
1247            EditPredictionModel::Sweep => {
1248                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1249            }
1250            EditPredictionModel::Mercury => {}
1251            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1252                zeta2::edit_prediction_accepted(self, current_prediction, cx)
1253            }
1254        }
1255    }
1256
    /// Background loop that batches rejected predictions and reports them to the
    /// `/predict_edits/reject` endpoint.
    ///
    /// Batching and flushing:
    /// - Each received rejection is appended to `batched`.
    /// - While the batch holds fewer than half of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`
    ///   entries, the loop waits for whichever comes first. Another rejection is added to the
    ///   batch; the `REJECT_REQUEST_DEBOUNCE` timer firing triggers a flush.
    /// - If the channel closes while waiting, the batch is flushed once before the loop exits.
    ///
    /// Sending and retries:
    /// - A flush sends at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the newest
    ///   entries, and removes them only if the request succeeds.
    /// - Entries from a failed request stay in `batched`. They are retried only after another
    ///   rejection arrives, since the loop then blocks on `rx.next()` again.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                // `select_biased!` polls the peek first, so an already-available rejection wins
                // over the timer.
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            // NOTE(review): panics if the reject URL cannot be built; presumably the base URL
            // is always valid — confirm.
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Send the newest `flush_count` entries; older leftovers remain queued.
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
                true,
            )
            .await;

            // Only forget the sent entries once the server accepted them.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1315
1316    fn reject_current_prediction(
1317        &mut self,
1318        reason: EditPredictionRejectReason,
1319        project: &Entity<Project>,
1320    ) {
1321        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1322            project_state.pending_predictions.clear();
1323            if let Some(prediction) = project_state.current_prediction.take() {
1324                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1325            }
1326        };
1327    }
1328
1329    fn did_show_current_prediction(
1330        &mut self,
1331        project: &Entity<Project>,
1332        display_type: edit_prediction_types::SuggestionDisplayType,
1333        cx: &mut Context<Self>,
1334    ) {
1335        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1336            return;
1337        };
1338
1339        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1340            return;
1341        };
1342
1343        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1344        let previous_shown_with = current_prediction.shown_with;
1345
1346        if previous_shown_with.is_none() || !is_jump {
1347            current_prediction.shown_with = Some(display_type);
1348        }
1349
1350        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1351
1352        if is_first_non_jump_show {
1353            current_prediction.was_shown = true;
1354        }
1355
1356        let display_type_changed = previous_shown_with != Some(display_type);
1357
1358        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1359            sweep_ai::edit_prediction_shown(
1360                &self.sweep_ai,
1361                self.client.clone(),
1362                &current_prediction.prediction,
1363                display_type,
1364                cx,
1365            );
1366        }
1367
1368        if is_first_non_jump_show {
1369            self.shown_predictions
1370                .push_front(current_prediction.prediction.clone());
1371            if self.shown_predictions.len() > 50 {
1372                let completion = self.shown_predictions.pop_back().unwrap();
1373                self.rated_predictions.remove(&completion.id);
1374            }
1375        }
1376    }
1377
1378    fn reject_prediction(
1379        &mut self,
1380        prediction_id: EditPredictionId,
1381        reason: EditPredictionRejectReason,
1382        was_shown: bool,
1383    ) {
1384        match self.edit_prediction_model {
1385            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1386                if self.custom_predict_edits_url.is_some() {
1387                    return;
1388                }
1389            }
1390            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1391        }
1392
1393        self.reject_predictions_tx
1394            .unbounded_send(EditPredictionRejection {
1395                request_id: prediction_id.to_string(),
1396                reason,
1397                was_shown,
1398            })
1399            .log_err();
1400    }
1401
1402    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1403        self.projects
1404            .get(&project.entity_id())
1405            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1406    }
1407
    /// Queues a prediction request at `position` in `buffer`, throttled per buffer.
    ///
    /// How the request is made:
    /// - The refresh is keyed on the buffer's entity id for throttling.
    /// - The request uses `PredictEditsRequestTrigger::Other`.
    /// - A returned result is tagged with `PredictionRequestedBy::Buffer`. If it targets another
    ///   buffer, `prediction_at` then offers a jump only in the requesting buffer.
    ///
    /// If the store entity can no longer be updated, the task resolves to `Ok(None)`.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &buffer,
                        position,
                        PredictEditsRequestTrigger::Other,
                        cx,
                    )
                })
                .log_err()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |_cx| {
                // Attach the requesting buffer to whatever the request produced.
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }
1443
    /// Queues a prediction request triggered by a diagnostics update in `project`.
    ///
    /// Does nothing when the project is not registered or already has a current prediction.
    ///
    /// Otherwise the refresh, throttled on the project's entity id, does the following:
    /// - Reads the active buffer and cursor, and gives up if predictions are disabled there.
    /// - Asks `next_diagnostic_location` for a target location.
    /// - Requests a prediction there with `PredictEditsRequestTrigger::Diagnostics`.
    ///
    /// If a current prediction appeared while the request was in flight, the result is turned
    /// into an `EditPredictionRejectReason::CurrentPreferred` rejection that keeps its id.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Capture the active buffer, its snapshot and the cursor. `None` (or a dropped
            // store) means there is nothing to do.
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    // Without a known cursor position, fall back to `Point::default()`.
                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                // Pick the location to predict at (presumably the next diagnostic relative
                // to the cursor — see `next_diagnostic_location`).
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // A prediction may have been set while this request was running; if so, keep
                // it and report this result as `CurrentPreferred` instead.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1530
1531    fn predictions_enabled_at(
1532        snapshot: &BufferSnapshot,
1533        position: Option<language::Anchor>,
1534        cx: &App,
1535    ) -> bool {
1536        let file = snapshot.file();
1537        let all_settings = all_language_settings(file, cx);
1538        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1539            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1540        {
1541            return false;
1542        }
1543
1544        if let Some(last_position) = position {
1545            let settings = snapshot.settings_at(last_position, cx);
1546
1547            if !settings.edit_predictions_disabled_in.is_empty()
1548                && let Some(scope) = snapshot.language_scope_at(last_position)
1549                && let Some(scope_name) = scope.override_name()
1550                && settings
1551                    .edit_predictions_disabled_in
1552                    .iter()
1553                    .any(|s| s == scope_name)
1554            {
1555                return false;
1556            }
1557        }
1558
1559        true
1560    }
1561
    /// Minimum spacing between two prediction refreshes keyed on the same throttle entity.
    /// `queue_prediction_refresh` waits out whatever remains of this window before
    /// requesting.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Tests disable throttling so queued refreshes run immediately.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1566
1567    fn queue_prediction_refresh(
1568        &mut self,
1569        project: Entity<Project>,
1570        throttle_entity: EntityId,
1571        cx: &mut Context<Self>,
1572        do_refresh: impl FnOnce(
1573            WeakEntity<Self>,
1574            &mut AsyncApp,
1575        )
1576            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1577        + 'static,
1578    ) {
1579        let project_state = self.get_or_init_project(&project, cx);
1580        let pending_prediction_id = project_state.next_pending_prediction_id;
1581        project_state.next_pending_prediction_id += 1;
1582        let last_request = project_state.last_prediction_refresh;
1583
1584        let task = cx.spawn(async move |this, cx| {
1585            if let Some((last_entity, last_timestamp)) = last_request
1586                && throttle_entity == last_entity
1587                && let Some(timeout) =
1588                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1589            {
1590                cx.background_executor().timer(timeout).await;
1591            }
1592
1593            // If this task was cancelled before the throttle timeout expired,
1594            // do not perform a request.
1595            let mut is_cancelled = true;
1596            this.update(cx, |this, cx| {
1597                let project_state = this.get_or_init_project(&project, cx);
1598                if !project_state
1599                    .cancelled_predictions
1600                    .remove(&pending_prediction_id)
1601                {
1602                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1603                    is_cancelled = false;
1604                }
1605            })
1606            .ok();
1607            if is_cancelled {
1608                return None;
1609            }
1610
1611            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1612            let new_prediction_id = new_prediction_result
1613                .as_ref()
1614                .map(|(prediction, _)| prediction.id.clone());
1615
1616            // When a prediction completes, remove it from the pending list, and cancel
1617            // any pending predictions that were enqueued before it.
1618            this.update(cx, |this, cx| {
1619                let project_state = this.get_or_init_project(&project, cx);
1620
1621                let is_cancelled = project_state
1622                    .cancelled_predictions
1623                    .remove(&pending_prediction_id);
1624
1625                let new_current_prediction = if !is_cancelled
1626                    && let Some((prediction_result, requested_by)) = new_prediction_result
1627                {
1628                    match prediction_result.prediction {
1629                        Ok(prediction) => {
1630                            let new_prediction = CurrentEditPrediction {
1631                                requested_by,
1632                                prediction,
1633                                was_shown: false,
1634                                shown_with: None,
1635                            };
1636
1637                            if let Some(current_prediction) =
1638                                project_state.current_prediction.as_ref()
1639                            {
1640                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1641                                {
1642                                    this.reject_current_prediction(
1643                                        EditPredictionRejectReason::Replaced,
1644                                        &project,
1645                                    );
1646
1647                                    Some(new_prediction)
1648                                } else {
1649                                    this.reject_prediction(
1650                                        new_prediction.prediction.id,
1651                                        EditPredictionRejectReason::CurrentPreferred,
1652                                        false,
1653                                    );
1654                                    None
1655                                }
1656                            } else {
1657                                Some(new_prediction)
1658                            }
1659                        }
1660                        Err(reject_reason) => {
1661                            this.reject_prediction(prediction_result.id, reject_reason, false);
1662                            None
1663                        }
1664                    }
1665                } else {
1666                    None
1667                };
1668
1669                let project_state = this.get_or_init_project(&project, cx);
1670
1671                if let Some(new_prediction) = new_current_prediction {
1672                    project_state.current_prediction = Some(new_prediction);
1673                }
1674
1675                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1676                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1677                    if pending_prediction.id == pending_prediction_id {
1678                        pending_predictions.remove(ix);
1679                        for pending_prediction in pending_predictions.drain(0..ix) {
1680                            project_state.cancel_pending_prediction(pending_prediction, cx)
1681                        }
1682                        break;
1683                    }
1684                }
1685                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1686                cx.notify();
1687            })
1688            .ok();
1689
1690            new_prediction_id
1691        });
1692
1693        if project_state.pending_predictions.len() <= 1 {
1694            project_state.pending_predictions.push(PendingPrediction {
1695                id: pending_prediction_id,
1696                task,
1697            });
1698        } else if project_state.pending_predictions.len() == 2 {
1699            let pending_prediction = project_state.pending_predictions.pop().unwrap();
1700            project_state.pending_predictions.push(PendingPrediction {
1701                id: pending_prediction_id,
1702                task,
1703            });
1704            project_state.cancel_pending_prediction(pending_prediction, cx);
1705        }
1706    }
1707
1708    pub fn request_prediction(
1709        &mut self,
1710        project: &Entity<Project>,
1711        active_buffer: &Entity<Buffer>,
1712        position: language::Anchor,
1713        trigger: PredictEditsRequestTrigger,
1714        cx: &mut Context<Self>,
1715    ) -> Task<Result<Option<EditPredictionResult>>> {
1716        self.request_prediction_internal(
1717            project.clone(),
1718            active_buffer.clone(),
1719            position,
1720            trigger,
1721            cx.has_flag::<Zeta2FeatureFlag>(),
1722            cx,
1723        )
1724    }
1725
1726    fn request_prediction_internal(
1727        &mut self,
1728        project: Entity<Project>,
1729        active_buffer: Entity<Buffer>,
1730        position: language::Anchor,
1731        trigger: PredictEditsRequestTrigger,
1732        allow_jump: bool,
1733        cx: &mut Context<Self>,
1734    ) -> Task<Result<Option<EditPredictionResult>>> {
1735        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1736
1737        self.get_or_init_project(&project, cx);
1738        let project_state = self.projects.get(&project.entity_id()).unwrap();
1739        let stored_events = project_state.events(cx);
1740        let has_events = !stored_events.is_empty();
1741        let events: Vec<Arc<zeta_prompt::Event>> =
1742            stored_events.into_iter().map(|e| e.event).collect();
1743        let debug_tx = project_state.debug_tx.clone();
1744
1745        let snapshot = active_buffer.read(cx).snapshot();
1746        let cursor_point = position.to_point(&snapshot);
1747        let current_offset = position.to_offset(&snapshot);
1748
1749        let mut user_actions: Vec<UserActionRecord> =
1750            project_state.user_actions.iter().cloned().collect();
1751
1752        if let Some(last_action) = user_actions.last() {
1753            if last_action.buffer_id == active_buffer.entity_id()
1754                && current_offset != last_action.offset
1755            {
1756                let timestamp_epoch_ms = SystemTime::now()
1757                    .duration_since(UNIX_EPOCH)
1758                    .map(|d| d.as_millis() as u64)
1759                    .unwrap_or(0);
1760                user_actions.push(UserActionRecord {
1761                    action_type: UserActionType::CursorMovement,
1762                    buffer_id: active_buffer.entity_id(),
1763                    line_number: cursor_point.row,
1764                    offset: current_offset,
1765                    timestamp_epoch_ms,
1766                });
1767            }
1768        }
1769        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1770        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1771        let diagnostic_search_range =
1772            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1773
1774        let related_files = if self.use_context {
1775            self.context_for_project(&project, cx)
1776        } else {
1777            Vec::new()
1778        };
1779
1780        let inputs = EditPredictionModelInput {
1781            project: project.clone(),
1782            buffer: active_buffer.clone(),
1783            snapshot: snapshot.clone(),
1784            position,
1785            events,
1786            related_files,
1787            recent_paths: project_state.recent_paths.clone(),
1788            trigger,
1789            diagnostic_search_range: diagnostic_search_range.clone(),
1790            debug_tx,
1791            user_actions,
1792        };
1793
1794        let can_collect_example = snapshot
1795            .file()
1796            .is_some_and(|file| self.can_collect_file(&project, file, cx))
1797            && self.can_collect_events(&inputs.events, cx);
1798
1799        if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
1800            let events_for_capture =
1801                self.edit_history_for_project_with_pause_split_last_event(&project, cx);
1802            if let Some(example_task) = capture_example::capture_example(
1803                project.clone(),
1804                active_buffer.clone(),
1805                position,
1806                events_for_capture,
1807                false,
1808                cx,
1809            ) {
1810                cx.spawn(async move |_this, _cx| {
1811                    let example = example_task.await?;
1812                    telemetry::event!("Edit Prediction Example Captured", example = example);
1813                    anyhow::Ok(())
1814                })
1815                .detach_and_log_err(cx);
1816            }
1817        }
1818        let task = match self.edit_prediction_model {
1819            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
1820            EditPredictionModel::Zeta2 { version } => {
1821                zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
1822            }
1823            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1824            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1825        };
1826
1827        cx.spawn(async move |this, cx| {
1828            let prediction = task.await?;
1829
1830            if prediction.is_none() && allow_jump {
1831                let cursor_point = position.to_point(&snapshot);
1832                if has_events
1833                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1834                        active_buffer.clone(),
1835                        &snapshot,
1836                        diagnostic_search_range,
1837                        cursor_point,
1838                        &project,
1839                        cx,
1840                    )
1841                    .await?
1842                {
1843                    return this
1844                        .update(cx, |this, cx| {
1845                            this.request_prediction_internal(
1846                                project,
1847                                jump_buffer,
1848                                jump_position,
1849                                trigger,
1850                                false,
1851                                cx,
1852                            )
1853                        })?
1854                        .await;
1855                }
1856
1857                return anyhow::Ok(None);
1858            }
1859
1860            Ok(prediction)
1861        })
1862    }
1863
1864    async fn next_diagnostic_location(
1865        active_buffer: Entity<Buffer>,
1866        active_buffer_snapshot: &BufferSnapshot,
1867        active_buffer_diagnostic_search_range: Range<Point>,
1868        active_buffer_cursor_point: Point,
1869        project: &Entity<Project>,
1870        cx: &mut AsyncApp,
1871    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1872        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1873        let mut jump_location = active_buffer_snapshot
1874            .diagnostic_groups(None)
1875            .into_iter()
1876            .filter_map(|(_, group)| {
1877                let range = &group.entries[group.primary_ix]
1878                    .range
1879                    .to_point(&active_buffer_snapshot);
1880                if range.overlaps(&active_buffer_diagnostic_search_range) {
1881                    None
1882                } else {
1883                    Some(range.start)
1884                }
1885            })
1886            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1887            .map(|position| {
1888                (
1889                    active_buffer.clone(),
1890                    active_buffer_snapshot.anchor_before(position),
1891                )
1892            });
1893
1894        if jump_location.is_none() {
1895            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1896                let file = buffer.file()?;
1897
1898                Some(ProjectPath {
1899                    worktree_id: file.worktree_id(cx),
1900                    path: file.path().clone(),
1901                })
1902            });
1903
1904            let buffer_task = project.update(cx, |project, cx| {
1905                let (path, _, _) = project
1906                    .diagnostic_summaries(false, cx)
1907                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1908                    .max_by_key(|(path, _, _)| {
1909                        // find the buffer with errors that shares most parent directories
1910                        path.path
1911                            .components()
1912                            .zip(
1913                                active_buffer_path
1914                                    .as_ref()
1915                                    .map(|p| p.path.components())
1916                                    .unwrap_or_default(),
1917                            )
1918                            .take_while(|(a, b)| a == b)
1919                            .count()
1920                    })?;
1921
1922                Some(project.open_buffer(path, cx))
1923            });
1924
1925            if let Some(buffer_task) = buffer_task {
1926                let closest_buffer = buffer_task.await?;
1927
1928                jump_location = closest_buffer
1929                    .read_with(cx, |buffer, _cx| {
1930                        buffer
1931                            .buffer_diagnostics(None)
1932                            .into_iter()
1933                            .min_by_key(|entry| entry.diagnostic.severity)
1934                            .map(|entry| entry.range.start)
1935                    })
1936                    .map(|position| (closest_buffer, position));
1937            }
1938        }
1939
1940        anyhow::Ok(jump_location)
1941    }
1942
1943    async fn send_raw_llm_request(
1944        request: RawCompletionRequest,
1945        client: Arc<Client>,
1946        custom_url: Option<Arc<Url>>,
1947        llm_token: LlmApiToken,
1948        app_version: Version,
1949        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1950        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
1951    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1952        let url = if let Some(custom_url) = custom_url {
1953            custom_url.as_ref().clone()
1954        } else {
1955            client
1956                .http_client()
1957                .build_zed_llm_url("/predict_edits/raw", &[])?
1958        };
1959
1960        #[cfg(feature = "cli-support")]
1961        let cache_key = if let Some(cache) = eval_cache {
1962            use collections::FxHasher;
1963            use std::hash::{Hash, Hasher};
1964
1965            let mut hasher = FxHasher::default();
1966            url.hash(&mut hasher);
1967            let request_str = serde_json::to_string_pretty(&request)?;
1968            request_str.hash(&mut hasher);
1969            let hash = hasher.finish();
1970
1971            let key = (eval_cache_kind, hash);
1972            if let Some(response_str) = cache.read(key) {
1973                return Ok((serde_json::from_str(&response_str)?, None));
1974            }
1975
1976            Some((cache, request_str, key))
1977        } else {
1978            None
1979        };
1980
1981        let (response, usage) = Self::send_api_request(
1982            |builder| {
1983                let req = builder
1984                    .uri(url.as_ref())
1985                    .body(serde_json::to_string(&request)?.into());
1986                Ok(req?)
1987            },
1988            client,
1989            llm_token,
1990            app_version,
1991            true,
1992        )
1993        .await?;
1994
1995        #[cfg(feature = "cli-support")]
1996        if let Some((cache, request, key)) = cache_key {
1997            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1998        }
1999
2000        Ok((response, usage))
2001    }
2002
2003    fn handle_api_response<T>(
2004        this: &WeakEntity<Self>,
2005        response: Result<(T, Option<EditPredictionUsage>)>,
2006        cx: &mut gpui::AsyncApp,
2007    ) -> Result<T> {
2008        match response {
2009            Ok((data, usage)) => {
2010                if let Some(usage) = usage {
2011                    this.update(cx, |this, cx| {
2012                        this.user_store.update(cx, |user_store, cx| {
2013                            user_store.update_edit_prediction_usage(usage, cx);
2014                        });
2015                    })
2016                    .ok();
2017                }
2018                Ok(data)
2019            }
2020            Err(err) => {
2021                if err.is::<ZedUpdateRequiredError>() {
2022                    cx.update(|cx| {
2023                        this.update(cx, |this, _cx| {
2024                            this.update_required = true;
2025                        })
2026                        .ok();
2027
2028                        let error_message: SharedString = err.to_string().into();
2029                        show_app_notification(
2030                            NotificationId::unique::<ZedUpdateRequiredError>(),
2031                            cx,
2032                            move |cx| {
2033                                cx.new(|cx| {
2034                                    ErrorMessagePrompt::new(error_message.clone(), cx)
2035                                        .with_link_button("Update Zed", "https://zed.dev/releases")
2036                                })
2037                            },
2038                        );
2039                    });
2040                }
2041                Err(err)
2042            }
2043        }
2044    }
2045
2046    async fn send_api_request<Res>(
2047        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2048        client: Arc<Client>,
2049        llm_token: LlmApiToken,
2050        app_version: Version,
2051        require_auth: bool,
2052    ) -> Result<(Res, Option<EditPredictionUsage>)>
2053    where
2054        Res: DeserializeOwned,
2055    {
2056        let http_client = client.http_client();
2057
2058        let mut token = if require_auth {
2059            Some(llm_token.acquire(&client).await?)
2060        } else {
2061            llm_token.acquire(&client).await.ok()
2062        };
2063        let mut did_retry = false;
2064
2065        loop {
2066            let request_builder = http_client::Request::builder().method(Method::POST);
2067
2068            let mut request_builder = request_builder
2069                .header("Content-Type", "application/json")
2070                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2071
2072            // Only add Authorization header if we have a token
2073            if let Some(ref token_value) = token {
2074                request_builder =
2075                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2076            }
2077
2078            let request = build(request_builder)?;
2079
2080            let mut response = http_client.send(request).await?;
2081
2082            if let Some(minimum_required_version) = response
2083                .headers()
2084                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2085                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2086            {
2087                anyhow::ensure!(
2088                    app_version >= minimum_required_version,
2089                    ZedUpdateRequiredError {
2090                        minimum_version: minimum_required_version
2091                    }
2092                );
2093            }
2094
2095            if response.status().is_success() {
2096                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2097
2098                let mut body = Vec::new();
2099                response.body_mut().read_to_end(&mut body).await?;
2100                return Ok((serde_json::from_slice(&body)?, usage));
2101            } else if !did_retry
2102                && token.is_some()
2103                && response
2104                    .headers()
2105                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
2106                    .is_some()
2107            {
2108                did_retry = true;
2109                token = Some(llm_token.refresh(&client).await?);
2110            } else {
2111                let mut body = String::new();
2112                response.body_mut().read_to_string(&mut body).await?;
2113                anyhow::bail!(
2114                    "Request failed with status: {:?}\nBody: {}",
2115                    response.status(),
2116                    body
2117                );
2118            }
2119        }
2120    }
2121
2122    pub fn refresh_context(
2123        &mut self,
2124        project: &Entity<Project>,
2125        buffer: &Entity<language::Buffer>,
2126        cursor_position: language::Anchor,
2127        cx: &mut Context<Self>,
2128    ) {
2129        if self.use_context {
2130            self.get_or_init_project(project, cx)
2131                .context
2132                .update(cx, |store, cx| {
2133                    store.refresh(buffer.clone(), cursor_position, cx);
2134                });
2135        }
2136    }
2137
2138    #[cfg(feature = "cli-support")]
2139    pub fn set_context_for_buffer(
2140        &mut self,
2141        project: &Entity<Project>,
2142        related_files: Vec<RelatedFile>,
2143        cx: &mut Context<Self>,
2144    ) {
2145        self.get_or_init_project(project, cx)
2146            .context
2147            .update(cx, |store, cx| {
2148                store.set_related_files(related_files, cx);
2149            });
2150    }
2151
2152    fn is_file_open_source(
2153        &self,
2154        project: &Entity<Project>,
2155        file: &Arc<dyn File>,
2156        cx: &App,
2157    ) -> bool {
2158        if !file.is_local() || file.is_private() {
2159            return false;
2160        }
2161        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2162            return false;
2163        };
2164        project_state
2165            .license_detection_watchers
2166            .get(&file.worktree_id(cx))
2167            .as_ref()
2168            .is_some_and(|watcher| watcher.is_project_open_source())
2169    }
2170
2171    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2172        self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2173    }
2174
2175    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2176        if !self.data_collection_choice.is_enabled(cx) {
2177            return false;
2178        }
2179        events.iter().all(|event| {
2180            matches!(
2181                event.as_ref(),
2182                zeta_prompt::Event::BufferChange {
2183                    in_open_source_repo: true,
2184                    ..
2185                }
2186            )
2187        })
2188    }
2189
2190    fn load_data_collection_choice() -> DataCollectionChoice {
2191        let choice = KEY_VALUE_STORE
2192            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2193            .log_err()
2194            .flatten();
2195
2196        match choice.as_deref() {
2197            Some("true") => DataCollectionChoice::Enabled,
2198            Some("false") => DataCollectionChoice::Disabled,
2199            Some(_) => {
2200                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2201                DataCollectionChoice::NotAnswered
2202            }
2203            None => DataCollectionChoice::NotAnswered,
2204        }
2205    }
2206
2207    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2208        self.data_collection_choice = self.data_collection_choice.toggle();
2209        let new_choice = self.data_collection_choice;
2210        let is_enabled = new_choice.is_enabled(cx);
2211        db::write_and_log(cx, move || {
2212            KEY_VALUE_STORE.write_kvp(
2213                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2214                is_enabled.to_string(),
2215            )
2216        });
2217    }
2218
    /// Iterates over the predictions stored in `shown_predictions`, in stored order
    /// (reversible, since the iterator is double-ended).
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2222
    /// Number of entries in `shown_predictions`, i.e. the number of items yielded by
    /// `shown_predictions()`.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2226
    /// Whether `id` is in `rated_predictions` (populated by `rate_prediction`).
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2230
2231    pub fn rate_prediction(
2232        &mut self,
2233        prediction: &EditPrediction,
2234        rating: EditPredictionRating,
2235        feedback: String,
2236        cx: &mut Context<Self>,
2237    ) {
2238        self.rated_predictions.insert(prediction.id.clone());
2239        telemetry::event!(
2240            "Edit Prediction Rated",
2241            rating,
2242            inputs = prediction.inputs,
2243            output = prediction
2244                .edit_preview
2245                .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2246            feedback
2247        );
2248        self.client.telemetry().flush_events().detach();
2249        cx.notify();
2250    }
2251
2252    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
2253        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
2254            && all_language_settings(None, cx).edit_predictions.use_context;
2255    }
2256}
2257
2258pub(crate) fn filter_redundant_excerpts(
2259    mut related_files: Vec<RelatedFile>,
2260    cursor_path: &Path,
2261    cursor_row_range: Range<u32>,
2262) -> Vec<RelatedFile> {
2263    for file in &mut related_files {
2264        if file.path.as_ref() == cursor_path {
2265            file.excerpts.retain(|excerpt| {
2266                excerpt.row_range.start < cursor_row_range.start
2267                    || excerpt.row_range.end > cursor_row_range.end
2268            });
2269        }
2270    }
2271    related_files.retain(|file| !file.excerpts.is_empty());
2272    related_files
2273}
2274
/// Error produced by `send_api_request` when a response's minimum-required-version
/// header names a version newer than the running app. `handle_api_response` shows it
/// to the user as an "Update Zed" notification.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version advertised by the server.
    minimum_version: Version,
}
2282
/// Key for `EvalCache` entries: the entry kind plus a 64-bit hash. For raw LLM
/// requests, the hash covers the request URL and the pretty-printed request.
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2285
/// Category of an `EvalCache` entry; the first component of `EvalCacheKey`.
///
/// Derives `Eq` and `Hash` (in addition to `PartialEq`) so cache implementations can
/// use keys directly in hashed collections.
#[cfg(feature = "cli-support")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2293
/// Formats the entry kind as a lowercase label (`context`, `search`, `prediction`).
#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Context => "context",
            Self::Search => "search",
            Self::Prediction => "prediction",
        };
        f.write_str(label)
    }
}
2304
/// Cache of serialized LLM responses used by evaluation tooling.
///
/// `send_raw_llm_request` calls `read` before sending a request. After a successful
/// response, it calls `write` with the serialized request and response.
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached serialized response for `key`, if any.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` (the serialized response) for `key`; `input` is the serialized request.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2310
/// The user's answer to whether edit prediction data may be collected.
///
/// Derives `PartialEq`/`Eq` so choices can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}
2317
2318impl DataCollectionChoice {
2319    pub fn is_enabled(self, cx: &App) -> bool {
2320        if cx.is_staff() {
2321            return true;
2322        }
2323        match self {
2324            Self::Enabled => true,
2325            Self::NotAnswered | Self::Disabled => false,
2326        }
2327    }
2328
2329    #[must_use]
2330    pub fn toggle(&self) -> DataCollectionChoice {
2331        match self {
2332            Self::Enabled => Self::Disabled,
2333            Self::Disabled => Self::Enabled,
2334            Self::NotAnswered => Self::Enabled,
2335        }
2336    }
2337}
2338
2339impl From<bool> for DataCollectionChoice {
2340    fn from(value: bool) -> Self {
2341        match value {
2342            true => DataCollectionChoice::Enabled,
2343            false => DataCollectionChoice::Disabled,
2344        }
2345    }
2346}
2347
2348struct ZedPredictUpsell;
2349
2350impl Dismissable for ZedPredictUpsell {
2351    const KEY: &'static str = "dismissed-edit-predict-upsell";
2352
2353    fn dismissed() -> bool {
2354        // To make this backwards compatible with older versions of Zed, we
2355        // check if the user has seen the previous Edit Prediction Onboarding
2356        // before, by checking the data collection choice which was written to
2357        // the database once the user clicked on "Accept and Enable"
2358        if KEY_VALUE_STORE
2359            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2360            .log_err()
2361            .is_some_and(|s| s.is_some())
2362        {
2363            return true;
2364        }
2365
2366        KEY_VALUE_STORE
2367            .read_kvp(Self::KEY)
2368            .log_err()
2369            .is_some_and(|s| s.is_some())
2370    }
2371}
2372
/// Returns `true` unless the edit prediction upsell has been dismissed, as reported
/// by `ZedPredictUpsell::dismissed`.
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2376
2377pub fn init(cx: &mut App) {
2378    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2379        workspace.register_action(
2380            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2381                ZedPredictModal::toggle(
2382                    workspace,
2383                    workspace.user_store().clone(),
2384                    workspace.client().clone(),
2385                    window,
2386                    cx,
2387                )
2388            },
2389        );
2390
2391        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2392            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2393                settings
2394                    .project
2395                    .all_languages
2396                    .features
2397                    .get_or_insert_default()
2398                    .edit_prediction_provider = Some(EditPredictionProvider::None)
2399            });
2400        });
2401    })
2402    .detach();
2403}