edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{
   5    self, PromptFormat, RawCompletionRequest, RawCompletionResponse,
   6};
   7use cloud_llm_client::{
   8    EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection,
   9    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  10    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  11};
  12use collections::{HashMap, HashSet};
  13use copilot::{Copilot, Reinstall, SignIn, SignOut};
  14use db::kvp::{Dismissable, KEY_VALUE_STORE};
  15use edit_prediction_context::EditPredictionExcerptOptions;
  16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::{
  19    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  20    channel::mpsc::{self, UnboundedReceiver},
  21    select_biased,
  22};
  23use gpui::BackgroundExecutor;
  24use gpui::http_client::Url;
  25use gpui::{
  26    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  27    http_client::{self, AsyncBody, Method},
  28    prelude::*,
  29};
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, RefreshLlmTokenListener};
  34use project::{Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{EditPredictionProvider, update_settings_file};
  39use std::collections::{VecDeque, hash_map};
  40use text::Edit;
  41use workspace::Workspace;
  42use zeta_prompt::ZetaVersion;
  43
  44use std::ops::Range;
  45use std::path::Path;
  46use std::rc::Rc;
  47use std::str::FromStr as _;
  48use std::sync::{Arc, LazyLock};
  49use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  50use std::{env, mem};
  51use thiserror::Error;
  52use util::{RangeExt as _, ResultExt as _};
  53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  54
  55pub mod cursor_excerpt;
  56pub mod example_spec;
  57mod license_detection;
  58pub mod mercury;
  59mod onboarding_modal;
  60pub mod open_ai_response;
  61mod prediction;
  62pub mod sweep_ai;
  63
  64pub mod udiff;
  65
  66mod capture_example;
  67mod zed_edit_prediction_delegate;
  68pub mod zeta1;
  69pub mod zeta2;
  70
  71#[cfg(test)]
  72mod edit_prediction_tests;
  73
  74use crate::capture_example::{
  75    should_sample_edit_prediction_example_capture, should_send_testing_zeta2_request,
  76};
  77use crate::license_detection::LicenseDetectionWatcher;
  78use crate::mercury::Mercury;
  79use crate::onboarding_modal::ZedPredictModal;
  80pub use crate::prediction::EditPrediction;
  81pub use crate::prediction::EditPredictionId;
  82use crate::prediction::EditPredictionResult;
  83pub use crate::sweep_ai::SweepAi;
  84pub use capture_example::capture_example;
  85pub use language_model::ApiKeyState;
  86pub use telemetry_events::EditPredictionRating;
  87pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  88
// Declares the `ResetOnboarding` and `ClearHistory` actions through gpui's `actions!`
// macro, with `edit_prediction` passed as the macro's first argument.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  98
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
// NOTE(review): presumably the row distance within which consecutive edits are still
// grouped into a single event; the usage is outside this view — confirm.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): presumably the editing pause (1 second) that delimits a group of
// changes; the usage is outside this view — confirm.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// Key string for the data-collection choice. NOTE(review): presumably the key used
// with `KEY_VALUE_STORE`; the usage is outside this view — confirm.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the delay (15 seconds) used to batch rejection reports
// before they are sent; the usage is outside this view — confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 105
 106pub struct SweepFeatureFlag;
 107
 108impl FeatureFlag for SweepFeatureFlag {
 109    const NAME: &str = "sweep-ai";
 110}
 111
 112pub struct MercuryFeatureFlag;
 113
 114impl FeatureFlag for MercuryFeatureFlag {
 115    const NAME: &str = "mercury";
 116}
 117
/// Default [`ZetaOptions`]; `EditPredictionStore::new` initializes its `options` field
/// with this value.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    // Excerpt sizing passed to the edit prediction context machinery.
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    prompt_format: PromptFormat::DEFAULT,
};
 126
 127static USE_OLLAMA: LazyLock<bool> =
 128    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 129
 130static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 131    match env::var("ZED_ZETA2_MODEL").as_deref() {
 132        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 133        Ok(model) => model,
 134        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 135        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 136    }
 137    .to_string()
 138});
 139
/// Feature flag registered under the name `zeta2`.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Reports the flag as enabled for staff.
    fn enabled_for_staff() -> bool {
        true
    }
}
 149
/// Feature flag registered under the name `edit-prediction-example-capture`.
pub struct EditPredictionExampleCaptureFeatureFlag;

impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
    const NAME: &'static str = "edit-prediction-example-capture";

    // Reports the flag as enabled for staff.
    fn enabled_for_staff() -> bool {
        true
    }
}
 159
/// Newtype holding the shared [`EditPredictionStore`] entity so it can be stored as a
/// gpui global (see `EditPredictionStore::global` / `try_global`).
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 164
/// Central state for edit predictions: per-project tracking, provider configuration and
/// the channel used to report rejected predictions.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Refreshed whenever the `RefreshLlmTokenListener` subscription set up in `new` fires.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Keyed by the project's entity id (see `get_or_init_project`).
    projects: HashMap<EntityId, ProjectState>,
    // Starts as `DEFAULT_OPTIONS`.
    options: ZetaOptions,
    // Initialized to `false`. NOTE(review): presumably set when the server demands a
    // newer app version; the writer is outside this view — confirm.
    update_required: bool,
    #[cfg(feature = "cli-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    // `new` starts with `Zeta2` at its default version.
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    // Sending half of the unbounded channel whose receiver `new` hands to
    // `handle_rejected_predictions` on a background task.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    // Taken from `ZED_PREDICT_EDITS_URL`, or a localhost Ollama endpoint when
    // `USE_OLLAMA` is set; otherwise `None`.
    custom_predict_edits_url: Option<Arc<Url>>,
}
 184
/// The model/provider used to produce edit predictions.
///
/// The derived `Default` is `Zeta1`; note that `EditPredictionStore::new` explicitly
/// starts with `Zeta2` instead.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2 {
        version: ZetaVersion,
    },
    Sweep,
    Mercury,
}
 195
/// Bundle of inputs describing a single prediction request: the project and buffer, a
/// buffer snapshot with the position of interest, plus history and context data.
// NOTE(review): which fields each provider actually consumes is outside this view.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    // Optional channel for emitting `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
 210
/// Tunable options held by [`EditPredictionStore`]; see `DEFAULT_OPTIONS` for the defaults.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Excerpt sizing options.
    pub context: EditPredictionExcerptOptions,
    /// Prompt format from `predict_edits_v3`.
    pub prompt_format: predict_edits_v3::PromptFormat,
}
 216
/// Events sent over a project's `debug_tx` channel, marking the start and end of
/// context retrieval and of an edit prediction.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
 224
/// Payload of [`DebugEvent::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    /// Entity id of the project the retrieval belongs to.
    pub project_entity_id: EntityId,
    /// When the event was created.
    pub timestamp: Instant,
    pub search_prompt: String,
}
 231
/// Payload of [`DebugEvent::ContextRetrievalFinished`].
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    /// Entity id of the project the retrieval belongs to.
    pub project_entity_id: EntityId,
    /// When the event was created.
    pub timestamp: Instant,
    /// Label/value pairs describing the retrieval (for example `"Cache Hits"`).
    pub metadata: Vec<(&'static str, SharedString)>,
}
 238
/// Payload of [`DebugEvent::EditPredictionStarted`].
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    /// Weak handle to the buffer the prediction targets.
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// The prompt, when one is available.
    pub prompt: Option<String>,
}
 245
/// Payload of [`DebugEvent::EditPredictionFinished`].
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    /// Weak handle to the buffer the prediction targets.
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// The model output, when one is available.
    pub model_output: Option<String>,
}
 252
/// Alias for `predict_edits_v3::DebugInfo`.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 254
/// Maximum number of `UserActionRecord`s kept per project; `record_user_action` evicts
/// the oldest record once this many are stored.
const USER_ACTION_HISTORY_SIZE: usize = 16;
 256
/// A single recorded user action, kept in a project's bounded action history.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    // NOTE(review): by its name, milliseconds since the Unix epoch — confirm at the
    // construction site, which is outside this view.
    pub timestamp_epoch_ms: u64,
}
 265
/// Kind of user action stored in a [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 274
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The event itself, shared via `Arc`.
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change described by `event`
    /// (`LastEvent::finalize` fills this from its `old_snapshot`).
    pub old_snapshot: TextBufferSnapshot,
}
 281
/// Per-project state tracked by [`EditPredictionStore`].
struct ProjectState {
    // Finalized edit events.
    events: VecDeque<StoredEvent>,
    // The most recent, not yet finalized change; `events()` appends its finalized
    // form to the stored events.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    // Keyed by buffer entity id (see `active_buffer`).
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // Holds at most two in-flight predictions (fixed `ArrayVec` capacity).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions passed to `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // Bounded by `USER_ACTION_HISTORY_SIZE` (see `record_user_action`).
    user_actions: VecDeque<UserActionRecord>,
    _subscription: gpui::Subscription,
    // Created on demand by `start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
 299
 300impl ProjectState {
 301    fn record_user_action(&mut self, action: UserActionRecord) {
 302        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 303            self.user_actions.pop_front();
 304        }
 305        self.user_actions.push_back(action);
 306    }
 307
 308    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 309        self.events
 310            .iter()
 311            .cloned()
 312            .chain(
 313                self.last_event
 314                    .as_ref()
 315                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 316            )
 317            .collect()
 318    }
 319
 320    pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
 321        self.events
 322            .iter()
 323            .cloned()
 324            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 325                let (one, two) = event.split_by_pause();
 326                let one = one.finalize(&self.license_detection_watchers, cx);
 327                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 328                one.into_iter().chain(two)
 329            }))
 330            .collect()
 331    }
 332
 333    fn cancel_pending_prediction(
 334        &mut self,
 335        pending_prediction: PendingPrediction,
 336        cx: &mut Context<EditPredictionStore>,
 337    ) {
 338        self.cancelled_predictions.insert(pending_prediction.id);
 339
 340        cx.spawn(async move |this, cx| {
 341            let Some(prediction_id) = pending_prediction.task.await else {
 342                return;
 343            };
 344
 345            this.update(cx, |this, _cx| {
 346                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 347            })
 348            .ok();
 349        })
 350        .detach()
 351    }
 352
 353    fn active_buffer(
 354        &self,
 355        project: &Entity<Project>,
 356        cx: &App,
 357    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 358        let project = project.read(cx);
 359        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 360        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 361        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 362        Some((active_buffer, registered_buffer.last_position))
 363    }
 364}
 365
/// The prediction currently held for a project, plus how it was requested and
/// whether/how it has been shown.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
 373
 374impl CurrentEditPrediction {
 375    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 376        let Some(new_edits) = self
 377            .prediction
 378            .interpolate(&self.prediction.buffer.read(cx))
 379        else {
 380            return false;
 381        };
 382
 383        if self.prediction.buffer != old_prediction.prediction.buffer {
 384            return true;
 385        }
 386
 387        let Some(old_edits) = old_prediction
 388            .prediction
 389            .interpolate(&old_prediction.prediction.buffer.read(cx))
 390        else {
 391            return true;
 392        };
 393
 394        let requested_by_buffer_id = self.requested_by.buffer_id();
 395
 396        // This reduces the occurrence of UI thrash from replacing edits
 397        //
 398        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 399        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 400            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 401            && old_edits.len() == 1
 402            && new_edits.len() == 1
 403        {
 404            let (old_range, old_text) = &old_edits[0];
 405            let (new_range, new_text) = &new_edits[0];
 406            new_range == old_range && new_text.starts_with(old_text.as_ref())
 407        } else {
 408            true
 409        }
 410    }
 411}
 412
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// Requested in response to a diagnostics update; carries no buffer id.
    DiagnosticsUpdate,
    /// Requested by the buffer with this entity id.
    Buffer(EntityId),
}
 418
 419impl PredictionRequestedBy {
 420    pub fn buffer_id(&self) -> Option<EntityId> {
 421        match self {
 422            PredictionRequestedBy::DiagnosticsUpdate => None,
 423            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 424        }
 425    }
 426}
 427
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    // Identifier recorded in `cancelled_predictions` when the request is cancelled.
    id: usize,
    // Resolves to the id of the produced prediction, or `None` when there is none.
    task: Task<Option<EditPredictionId>>,
}
 433
/// A prediction from the perspective of a buffer.
// NOTE(review): presumably `Local` means the prediction edits this buffer and `Jump`
// means it targets another location; the construction site is outside this view.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
 440
 441#[cfg(test)]
 442impl std::ops::Deref for BufferEditPrediction<'_> {
 443    type Target = EditPrediction;
 444
 445    fn deref(&self) -> &Self::Target {
 446        match self {
 447            BufferEditPrediction::Local { prediction } => prediction,
 448            BufferEditPrediction::Jump { prediction } => prediction,
 449        }
 450    }
 451}
 452
/// State kept for each buffer registered with a project.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    // Returned by `ProjectState::active_buffer` alongside the buffer.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
 459
/// The most recent, still-accumulating change to a buffer, described by the snapshots
/// (and files) before and after it.
#[derive(Clone)]
struct LastEvent {
    // `finalize` diffs `old_snapshot` against `new_snapshot`.
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    // When present, `split_by_pause` uses it as the boundary between the two halves.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 470
 471impl LastEvent {
 472    pub fn finalize(
 473        &self,
 474        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 475        cx: &App,
 476    ) -> Option<StoredEvent> {
 477        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 478        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 479
 480        let in_open_source_repo =
 481            [self.new_file.as_ref(), self.old_file.as_ref()]
 482                .iter()
 483                .all(|file| {
 484                    file.is_some_and(|file| {
 485                        license_detection_watchers
 486                            .get(&file.worktree_id(cx))
 487                            .is_some_and(|watcher| watcher.is_project_open_source())
 488                    })
 489                });
 490
 491        let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 492
 493        if path == old_path && diff.is_empty() {
 494            None
 495        } else {
 496            Some(StoredEvent {
 497                event: Arc::new(zeta_prompt::Event::BufferChange {
 498                    old_path,
 499                    path,
 500                    diff,
 501                    in_open_source_repo,
 502                    // TODO: Actually detect if this edit was predicted or not
 503                    predicted: false,
 504                }),
 505                old_snapshot: self.old_snapshot.clone(),
 506            })
 507        }
 508    }
 509
 510    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 511        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 512            return (self.clone(), None);
 513        };
 514
 515        let before = LastEvent {
 516            old_snapshot: self.old_snapshot.clone(),
 517            new_snapshot: boundary_snapshot.clone(),
 518            old_file: self.old_file.clone(),
 519            new_file: self.new_file.clone(),
 520            edit_range: None,
 521            snapshot_after_last_editing_pause: None,
 522            last_edit_time: self.last_edit_time,
 523        };
 524
 525        let after = LastEvent {
 526            old_snapshot: boundary_snapshot.clone(),
 527            new_snapshot: self.new_snapshot.clone(),
 528            old_file: self.old_file.clone(),
 529            new_file: self.new_file.clone(),
 530            edit_range: None,
 531            snapshot_after_last_editing_pause: None,
 532            last_edit_time: self.last_edit_time,
 533        };
 534
 535        (before, Some(after))
 536    }
 537}
 538
 539pub(crate) fn compute_diff_between_snapshots(
 540    old_snapshot: &TextBufferSnapshot,
 541    new_snapshot: &TextBufferSnapshot,
 542) -> Option<String> {
 543    let edits: Vec<Edit<usize>> = new_snapshot
 544        .edits_since::<usize>(&old_snapshot.version)
 545        .collect();
 546
 547    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 548
 549    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 550    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 551    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 552    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 553
 554    const CONTEXT_LINES: u32 = 3;
 555
 556    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 557    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 558    let old_context_end_row =
 559        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 560    let new_context_end_row =
 561        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 562
 563    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 564    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 565    let old_end_line_offset = old_snapshot
 566        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 567    let new_end_line_offset = new_snapshot
 568        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 569    let old_edit_range = old_start_line_offset..old_end_line_offset;
 570    let new_edit_range = new_start_line_offset..new_end_line_offset;
 571
 572    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 573    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 574
 575    let diff = language::unified_diff_with_offsets(
 576        &old_region_text,
 577        &new_region_text,
 578        old_context_start_row,
 579        new_context_start_row,
 580    );
 581
 582    Some(diff)
 583}
 584
 585fn buffer_path_with_id_fallback(
 586    file: Option<&Arc<dyn File>>,
 587    snapshot: &TextBufferSnapshot,
 588    cx: &App,
 589) -> Arc<Path> {
 590    if let Some(file) = file {
 591        file.full_path(cx).into()
 592    } else {
 593        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 594    }
 595}
 596
 597impl EditPredictionStore {
 598    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 599        cx.try_global::<EditPredictionStoreGlobal>()
 600            .map(|global| global.0.clone())
 601    }
 602
 603    pub fn global(
 604        client: &Arc<Client>,
 605        user_store: &Entity<UserStore>,
 606        cx: &mut App,
 607    ) -> Entity<Self> {
 608        cx.try_global::<EditPredictionStoreGlobal>()
 609            .map(|global| global.0.clone())
 610            .unwrap_or_else(|| {
 611                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 612                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 613                ep_store
 614            })
 615    }
 616
 617    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 618        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 619        let data_collection_choice = Self::load_data_collection_choice();
 620
 621        let llm_token = LlmApiToken::default();
 622
 623        let (reject_tx, reject_rx) = mpsc::unbounded();
 624        cx.background_spawn({
 625            let client = client.clone();
 626            let llm_token = llm_token.clone();
 627            let app_version = AppVersion::global(cx);
 628            let background_executor = cx.background_executor().clone();
 629            async move {
 630                Self::handle_rejected_predictions(
 631                    reject_rx,
 632                    client,
 633                    llm_token,
 634                    app_version,
 635                    background_executor,
 636                )
 637                .await
 638            }
 639        })
 640        .detach();
 641
 642        let this = Self {
 643            projects: HashMap::default(),
 644            client,
 645            user_store,
 646            options: DEFAULT_OPTIONS,
 647            llm_token,
 648            _llm_token_subscription: cx.subscribe(
 649                &refresh_llm_token_listener,
 650                |this, _listener, _event, cx| {
 651                    let client = this.client.clone();
 652                    let llm_token = this.llm_token.clone();
 653                    cx.spawn(async move |_this, _cx| {
 654                        llm_token.refresh(&client).await?;
 655                        anyhow::Ok(())
 656                    })
 657                    .detach_and_log_err(cx);
 658                },
 659            ),
 660            update_required: false,
 661            #[cfg(feature = "cli-support")]
 662            eval_cache: None,
 663            edit_prediction_model: EditPredictionModel::Zeta2 {
 664                version: Default::default(),
 665            },
 666            sweep_ai: SweepAi::new(cx),
 667            mercury: Mercury::new(cx),
 668
 669            data_collection_choice,
 670            reject_predictions_tx: reject_tx,
 671            rated_predictions: Default::default(),
 672            shown_predictions: Default::default(),
 673            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
 674                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
 675                Err(_) => {
 676                    if *USE_OLLAMA {
 677                        Some(
 678                            Url::parse("http://localhost:11434/v1/chat/completions")
 679                                .unwrap()
 680                                .into(),
 681                        )
 682                    } else {
 683                        None
 684                    }
 685                }
 686            },
 687        };
 688
 689        this
 690    }
 691
    /// Test-only: overrides the custom predict-edits URL.
    #[cfg(test)]
    pub fn set_custom_predict_edits_url(&mut self, url: Url) {
        self.custom_predict_edits_url = Some(url.into());
    }
 696
    /// Selects the model used for edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
 700
 701    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 702        self.sweep_ai.api_token.read(cx).has_key()
 703    }
 704
 705    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 706        self.mercury.api_token.read(cx).has_key()
 707    }
 708
    /// Installs an evaluation cache (only available with the `cli-support` feature).
    #[cfg(feature = "cli-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 713
    /// Returns the current options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 717
    /// Replaces the current options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 721
 722    pub fn clear_history(&mut self) {
 723        for project_state in self.projects.values_mut() {
 724            project_state.events.clear();
 725            project_state.last_event.take();
 726        }
 727    }
 728
 729    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 730        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 731            project_state.events.clear();
 732            project_state.last_event.take();
 733        }
 734    }
 735
 736    pub fn edit_history_for_project(
 737        &self,
 738        project: &Entity<Project>,
 739        cx: &App,
 740    ) -> Vec<StoredEvent> {
 741        self.projects
 742            .get(&project.entity_id())
 743            .map(|project_state| project_state.events(cx))
 744            .unwrap_or_default()
 745    }
 746
 747    pub fn edit_history_for_project_with_pause_split_last_event(
 748        &self,
 749        project: &Entity<Project>,
 750        cx: &App,
 751    ) -> Vec<StoredEvent> {
 752        self.projects
 753            .get(&project.entity_id())
 754            .map(|project_state| project_state.events_split_by_pause(cx))
 755            .unwrap_or_default()
 756    }
 757
 758    pub fn context_for_project<'a>(
 759        &'a self,
 760        project: &Entity<Project>,
 761        cx: &'a mut App,
 762    ) -> Vec<RelatedFile> {
 763        self.projects
 764            .get(&project.entity_id())
 765            .map(|project| {
 766                project
 767                    .context
 768                    .update(cx, |context, cx| context.related_files(cx))
 769            })
 770            .unwrap_or_default()
 771    }
 772
 773    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 774        self.projects
 775            .get(&project.entity_id())
 776            .and_then(|project| project.copilot.clone())
 777    }
 778
 779    pub fn start_copilot_for_project(
 780        &mut self,
 781        project: &Entity<Project>,
 782        cx: &mut Context<Self>,
 783    ) -> Option<Entity<Copilot>> {
 784        let state = self.get_or_init_project(project, cx);
 785
 786        if state.copilot.is_some() {
 787            return state.copilot.clone();
 788        }
 789        let _project = project.clone();
 790        let project = project.read(cx);
 791
 792        let node = project.node_runtime().cloned();
 793        if let Some(node) = node {
 794            let next_id = project.languages().next_language_server_id();
 795            let fs = project.fs().clone();
 796
 797            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 798            state.copilot = Some(copilot.clone());
 799            Some(copilot)
 800        } else {
 801            None
 802        }
 803    }
 804
 805    pub fn context_for_project_with_buffers<'a>(
 806        &'a self,
 807        project: &Entity<Project>,
 808        cx: &'a mut App,
 809    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 810        self.projects
 811            .get(&project.entity_id())
 812            .map(|project| {
 813                project
 814                    .context
 815                    .update(cx, |context, cx| context.related_files_with_buffers(cx))
 816            })
 817            .unwrap_or_default()
 818    }
 819
 820    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 821        if matches!(
 822            self.edit_prediction_model,
 823            EditPredictionModel::Zeta2 { .. }
 824        ) {
 825            self.user_store.read(cx).edit_prediction_usage()
 826        } else {
 827            None
 828        }
 829    }
 830
    /// Ensures per-project state exists for `project`, creating it if necessary.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
 834
    /// Registers `buffer` with `project`'s state, creating that state first if needed.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
 844
 845    fn get_or_init_project(
 846        &mut self,
 847        project: &Entity<Project>,
 848        cx: &mut Context<Self>,
 849    ) -> &mut ProjectState {
 850        let entity_id = project.entity_id();
 851        self.projects
 852            .entry(entity_id)
 853            .or_insert_with(|| ProjectState {
 854                context: {
 855                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 856                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 857                        this.handle_excerpt_store_event(entity_id, event);
 858                    })
 859                    .detach();
 860                    related_excerpt_store
 861                },
 862                events: VecDeque::new(),
 863                last_event: None,
 864                recent_paths: VecDeque::new(),
 865                debug_tx: None,
 866                registered_buffers: HashMap::default(),
 867                current_prediction: None,
 868                cancelled_predictions: HashSet::default(),
 869                pending_predictions: ArrayVec::new(),
 870                next_pending_prediction_id: 0,
 871                last_prediction_refresh: None,
 872                license_detection_watchers: HashMap::default(),
 873                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
 874                _subscription: cx.subscribe(&project, Self::handle_project_event),
 875                copilot: None,
 876            })
 877    }
 878
    /// Drops all state tracked for `project`. Does nothing if it is not registered.
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
 882
 883    fn handle_excerpt_store_event(
 884        &mut self,
 885        project_entity_id: EntityId,
 886        event: &RelatedExcerptStoreEvent,
 887    ) {
 888        if let Some(project_state) = self.projects.get(&project_entity_id) {
 889            if let Some(debug_tx) = project_state.debug_tx.clone() {
 890                match event {
 891                    RelatedExcerptStoreEvent::StartedRefresh => {
 892                        debug_tx
 893                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 894                                ContextRetrievalStartedDebugEvent {
 895                                    project_entity_id: project_entity_id,
 896                                    timestamp: Instant::now(),
 897                                    search_prompt: String::new(),
 898                                },
 899                            ))
 900                            .ok();
 901                    }
 902                    RelatedExcerptStoreEvent::FinishedRefresh {
 903                        cache_hit_count,
 904                        cache_miss_count,
 905                        mean_definition_latency,
 906                        max_definition_latency,
 907                    } => {
 908                        debug_tx
 909                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 910                                ContextRetrievalFinishedDebugEvent {
 911                                    project_entity_id: project_entity_id,
 912                                    timestamp: Instant::now(),
 913                                    metadata: vec![
 914                                        (
 915                                            "Cache Hits",
 916                                            format!(
 917                                                "{}/{}",
 918                                                cache_hit_count,
 919                                                cache_hit_count + cache_miss_count
 920                                            )
 921                                            .into(),
 922                                        ),
 923                                        (
 924                                            "Max LSP Time",
 925                                            format!("{} ms", max_definition_latency.as_millis())
 926                                                .into(),
 927                                        ),
 928                                        (
 929                                            "Mean LSP Time",
 930                                            format!("{} ms", mean_definition_latency.as_millis())
 931                                                .into(),
 932                                        ),
 933                                    ],
 934                                },
 935                            ))
 936                            .ok();
 937                    }
 938                }
 939            }
 940        }
 941    }
 942
 943    pub fn debug_info(
 944        &mut self,
 945        project: &Entity<Project>,
 946        cx: &mut Context<Self>,
 947    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 948        let project_state = self.get_or_init_project(project, cx);
 949        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 950        project_state.debug_tx = Some(debug_watch_tx);
 951        debug_watch_rx
 952    }
 953
 954    fn handle_project_event(
 955        &mut self,
 956        project: Entity<Project>,
 957        event: &project::Event,
 958        cx: &mut Context<Self>,
 959    ) {
 960        // TODO [zeta2] init with recent paths
 961        match event {
 962            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 963                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 964                    return;
 965                };
 966                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 967                if let Some(path) = path {
 968                    if let Some(ix) = project_state
 969                        .recent_paths
 970                        .iter()
 971                        .position(|probe| probe == &path)
 972                    {
 973                        project_state.recent_paths.remove(ix);
 974                    }
 975                    project_state.recent_paths.push_front(path);
 976                }
 977            }
 978            project::Event::DiagnosticsUpdated { .. } => {
 979                if cx.has_flag::<Zeta2FeatureFlag>() {
 980                    self.refresh_prediction_from_diagnostics(project, cx);
 981                }
 982            }
 983            _ => (),
 984        }
 985    }
 986
 987    fn register_buffer_impl<'a>(
 988        project_state: &'a mut ProjectState,
 989        buffer: &Entity<Buffer>,
 990        project: &Entity<Project>,
 991        cx: &mut Context<Self>,
 992    ) -> &'a mut RegisteredBuffer {
 993        let buffer_id = buffer.entity_id();
 994
 995        if let Some(file) = buffer.read(cx).file() {
 996            let worktree_id = file.worktree_id(cx);
 997            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
 998                project_state
 999                    .license_detection_watchers
1000                    .entry(worktree_id)
1001                    .or_insert_with(|| {
1002                        let project_entity_id = project.entity_id();
1003                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
1004                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
1005                            else {
1006                                return;
1007                            };
1008                            project_state
1009                                .license_detection_watchers
1010                                .remove(&worktree_id);
1011                        })
1012                        .detach();
1013                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
1014                    });
1015            }
1016        }
1017
1018        match project_state.registered_buffers.entry(buffer_id) {
1019            hash_map::Entry::Occupied(entry) => entry.into_mut(),
1020            hash_map::Entry::Vacant(entry) => {
1021                let buf = buffer.read(cx);
1022                let snapshot = buf.text_snapshot();
1023                let file = buf.file().cloned();
1024                let project_entity_id = project.entity_id();
1025                entry.insert(RegisteredBuffer {
1026                    snapshot,
1027                    file,
1028                    last_position: None,
1029                    _subscriptions: [
1030                        cx.subscribe(buffer, {
1031                            let project = project.downgrade();
1032                            move |this, buffer, event, cx| {
1033                                if let language::BufferEvent::Edited = event
1034                                    && let Some(project) = project.upgrade()
1035                                {
1036                                    this.report_changes_for_buffer(&buffer, &project, cx);
1037                                }
1038                            }
1039                        }),
1040                        cx.observe_release(buffer, move |this, _buffer, _cx| {
1041                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
1042                            else {
1043                                return;
1044                            };
1045                            project_state.registered_buffers.remove(&buffer_id);
1046                        }),
1047                    ],
1048                })
1049            }
1050        }
1051    }
1052
    /// Records the edits made to `buffer` since its last registered snapshot.
    ///
    /// The method does three things:
    /// 1. Refreshes the buffer's registered snapshot and file, returning early
    ///    when the snapshot version has not changed.
    /// 2. Classifies the edits and records a `UserActionRecord` positioned at
    ///    the end of the last edit.
    /// 3. Updates the project's edit-event history: the change is either
    ///    folded into the in-progress `last_event` or the in-progress event is
    ///    finalized into `events` and a new one is started.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing to report if the text version is unchanged.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        // Swap in the new file/snapshot, keeping the old ones to describe the
        // transition.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        // Accumulate edit statistics. `edit_range` spans from the start of the
        // first edit to the end of the last one; `last_offset` is the end of
        // the last edit in the new snapshot.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        if num_edits > 0 {
            // A "char" action is one where every edit changed exactly one unit
            // of offset (total == number of edits); anything larger counts as
            // a "selection" action. Mixed delete+insert edits are classified
            // by their inserted size.
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                // Falls back to 0 if the system clock is before the Unix epoch.
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // The in-progress event can only be extended if this change
            // directly follows it: same buffer, and our old snapshot is exactly
            // the event's newest snapshot.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // Additionally require both edit ranges to exist and to either
            // overlap or lie within `CHANGE_GROUPING_LINE_SPAN` rows of each
            // other (both resolved against the new snapshot).
            let should_coalesce = is_next_snapshot_of_same_buffer
                && edit_range
                    .as_ref()
                    .zip(last_event.edit_range.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        if a.start > b.end {
                            a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else if b.start > a.end {
                            b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else {
                            true
                        }
                    });

            if should_coalesce {
                // If at least `LAST_CHANGE_GROUPING_TIME` passed since the
                // previous edit, remember the snapshot from before this edit as
                // the state after the last editing pause.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                // Extend the in-progress event in place. Note that the stored
                // edit range is replaced by the latest one, not merged.
                last_event.edit_range = edit_range;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescing: finalize the in-progress event (if any) into the
        // history, evicting the oldest entry first so that `events` stays
        // below `EVENT_COUNT_MAX` entries.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        // Start a new in-progress event for this change.
        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1176
1177    fn prediction_at(
1178        &mut self,
1179        buffer: &Entity<Buffer>,
1180        position: Option<language::Anchor>,
1181        project: &Entity<Project>,
1182        cx: &App,
1183    ) -> Option<BufferEditPrediction<'_>> {
1184        let project_state = self.projects.get_mut(&project.entity_id())?;
1185        if let Some(position) = position
1186            && let Some(buffer) = project_state
1187                .registered_buffers
1188                .get_mut(&buffer.entity_id())
1189        {
1190            buffer.last_position = Some(position);
1191        }
1192
1193        let CurrentEditPrediction {
1194            requested_by,
1195            prediction,
1196            ..
1197        } = project_state.current_prediction.as_ref()?;
1198
1199        if prediction.targets_buffer(buffer.read(cx)) {
1200            Some(BufferEditPrediction::Local { prediction })
1201        } else {
1202            let show_jump = match requested_by {
1203                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1204                    requested_by_buffer_id == &buffer.entity_id()
1205                }
1206                PredictionRequestedBy::DiagnosticsUpdate => true,
1207            };
1208
1209            if show_jump {
1210                Some(BufferEditPrediction::Jump { prediction })
1211            } else {
1212                None
1213            }
1214        }
1215    }
1216
1217    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1218        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1219            return;
1220        };
1221
1222        let Some(current_prediction) = project_state.current_prediction.take() else {
1223            return;
1224        };
1225
1226        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1227            project_state.cancel_pending_prediction(pending_prediction, cx);
1228        }
1229
1230        match self.edit_prediction_model {
1231            EditPredictionModel::Sweep => {
1232                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1233            }
1234            EditPredictionModel::Mercury => {}
1235            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1236                zeta2::edit_prediction_accepted(self, current_prediction, cx)
1237            }
1238        }
1239    }
1240
1241    async fn handle_rejected_predictions(
1242        rx: UnboundedReceiver<EditPredictionRejection>,
1243        client: Arc<Client>,
1244        llm_token: LlmApiToken,
1245        app_version: Version,
1246        background_executor: BackgroundExecutor,
1247    ) {
1248        let mut rx = std::pin::pin!(rx.peekable());
1249        let mut batched = Vec::new();
1250
1251        while let Some(rejection) = rx.next().await {
1252            batched.push(rejection);
1253
1254            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1255                select_biased! {
1256                    next = rx.as_mut().peek().fuse() => {
1257                        if next.is_some() {
1258                            continue;
1259                        }
1260                    }
1261                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1262                }
1263            }
1264
1265            let url = client
1266                .http_client()
1267                .build_zed_llm_url("/predict_edits/reject", &[])
1268                .unwrap();
1269
1270            let flush_count = batched
1271                .len()
1272                // in case items have accumulated after failure
1273                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1274            let start = batched.len() - flush_count;
1275
1276            let body = RejectEditPredictionsBodyRef {
1277                rejections: &batched[start..],
1278            };
1279
1280            let result = Self::send_api_request::<()>(
1281                |builder| {
1282                    let req = builder
1283                        .uri(url.as_ref())
1284                        .body(serde_json::to_string(&body)?.into());
1285                    anyhow::Ok(req?)
1286                },
1287                client.clone(),
1288                llm_token.clone(),
1289                app_version.clone(),
1290                true,
1291            )
1292            .await;
1293
1294            if result.log_err().is_some() {
1295                batched.drain(start..);
1296            }
1297        }
1298    }
1299
1300    fn reject_current_prediction(
1301        &mut self,
1302        reason: EditPredictionRejectReason,
1303        project: &Entity<Project>,
1304    ) {
1305        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1306            project_state.pending_predictions.clear();
1307            if let Some(prediction) = project_state.current_prediction.take() {
1308                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1309            }
1310        };
1311    }
1312
1313    fn did_show_current_prediction(
1314        &mut self,
1315        project: &Entity<Project>,
1316        display_type: edit_prediction_types::SuggestionDisplayType,
1317        cx: &mut Context<Self>,
1318    ) {
1319        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1320            return;
1321        };
1322
1323        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1324            return;
1325        };
1326
1327        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1328        let previous_shown_with = current_prediction.shown_with;
1329
1330        if previous_shown_with.is_none() || !is_jump {
1331            current_prediction.shown_with = Some(display_type);
1332        }
1333
1334        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1335
1336        if is_first_non_jump_show {
1337            current_prediction.was_shown = true;
1338        }
1339
1340        let display_type_changed = previous_shown_with != Some(display_type);
1341
1342        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1343            sweep_ai::edit_prediction_shown(
1344                &self.sweep_ai,
1345                self.client.clone(),
1346                &current_prediction.prediction,
1347                display_type,
1348                cx,
1349            );
1350        }
1351
1352        if is_first_non_jump_show {
1353            self.shown_predictions
1354                .push_front(current_prediction.prediction.clone());
1355            if self.shown_predictions.len() > 50 {
1356                let completion = self.shown_predictions.pop_back().unwrap();
1357                self.rated_predictions.remove(&completion.id);
1358            }
1359        }
1360    }
1361
1362    fn reject_prediction(
1363        &mut self,
1364        prediction_id: EditPredictionId,
1365        reason: EditPredictionRejectReason,
1366        was_shown: bool,
1367    ) {
1368        match self.edit_prediction_model {
1369            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1370                if self.custom_predict_edits_url.is_some() {
1371                    return;
1372                }
1373            }
1374            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1375        }
1376
1377        self.reject_predictions_tx
1378            .unbounded_send(EditPredictionRejection {
1379                request_id: prediction_id.to_string(),
1380                reason,
1381                was_shown,
1382            })
1383            .log_err();
1384    }
1385
1386    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1387        self.projects
1388            .get(&project.entity_id())
1389            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1390    }
1391
1392    pub fn refresh_prediction_from_buffer(
1393        &mut self,
1394        project: Entity<Project>,
1395        buffer: Entity<Buffer>,
1396        position: language::Anchor,
1397        cx: &mut Context<Self>,
1398    ) {
1399        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1400            let Some(request_task) = this
1401                .update(cx, |this, cx| {
1402                    this.request_prediction(
1403                        &project,
1404                        &buffer,
1405                        position,
1406                        PredictEditsRequestTrigger::Other,
1407                        cx,
1408                    )
1409                })
1410                .log_err()
1411            else {
1412                return Task::ready(anyhow::Ok(None));
1413            };
1414
1415            cx.spawn(async move |_cx| {
1416                request_task.await.map(|prediction_result| {
1417                    prediction_result.map(|prediction_result| {
1418                        (
1419                            prediction_result,
1420                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1421                        )
1422                    })
1423                })
1424            })
1425        })
1426    }
1427
1428    pub fn refresh_prediction_from_diagnostics(
1429        &mut self,
1430        project: Entity<Project>,
1431        cx: &mut Context<Self>,
1432    ) {
1433        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1434            return;
1435        };
1436
1437        // Prefer predictions from buffer
1438        if project_state.current_prediction.is_some() {
1439            return;
1440        };
1441
1442        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1443            let Some((active_buffer, snapshot, cursor_point)) = this
1444                .read_with(cx, |this, cx| {
1445                    let project_state = this.projects.get(&project.entity_id())?;
1446                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
1447                    let snapshot = buffer.read(cx).snapshot();
1448
1449                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
1450                        return None;
1451                    }
1452
1453                    let cursor_point = position
1454                        .map(|pos| pos.to_point(&snapshot))
1455                        .unwrap_or_default();
1456
1457                    Some((buffer, snapshot, cursor_point))
1458                })
1459                .log_err()
1460                .flatten()
1461            else {
1462                return Task::ready(anyhow::Ok(None));
1463            };
1464
1465            cx.spawn(async move |cx| {
1466                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1467                    active_buffer,
1468                    &snapshot,
1469                    Default::default(),
1470                    cursor_point,
1471                    &project,
1472                    cx,
1473                )
1474                .await?
1475                else {
1476                    return anyhow::Ok(None);
1477                };
1478
1479                let Some(prediction_result) = this
1480                    .update(cx, |this, cx| {
1481                        this.request_prediction(
1482                            &project,
1483                            &jump_buffer,
1484                            jump_position,
1485                            PredictEditsRequestTrigger::Diagnostics,
1486                            cx,
1487                        )
1488                    })?
1489                    .await?
1490                else {
1491                    return anyhow::Ok(None);
1492                };
1493
1494                this.update(cx, |this, cx| {
1495                    Some((
1496                        if this
1497                            .get_or_init_project(&project, cx)
1498                            .current_prediction
1499                            .is_none()
1500                        {
1501                            prediction_result
1502                        } else {
1503                            EditPredictionResult {
1504                                id: prediction_result.id,
1505                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1506                            }
1507                        },
1508                        PredictionRequestedBy::DiagnosticsUpdate,
1509                    ))
1510                })
1511            })
1512        });
1513    }
1514
1515    fn predictions_enabled_at(
1516        snapshot: &BufferSnapshot,
1517        position: Option<language::Anchor>,
1518        cx: &App,
1519    ) -> bool {
1520        let file = snapshot.file();
1521        let all_settings = all_language_settings(file, cx);
1522        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1523            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1524        {
1525            return false;
1526        }
1527
1528        if let Some(last_position) = position {
1529            let settings = snapshot.settings_at(last_position, cx);
1530
1531            if !settings.edit_predictions_disabled_in.is_empty()
1532                && let Some(scope) = snapshot.language_scope_at(last_position)
1533                && let Some(scope_name) = scope.override_name()
1534                && settings
1535                    .edit_predictions_disabled_in
1536                    .iter()
1537                    .any(|s| s == scope_name)
1538            {
1539                return false;
1540            }
1541        }
1542
1543        true
1544    }
1545
1546    #[cfg(not(test))]
1547    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1548    #[cfg(test)]
1549    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1550
1551    fn queue_prediction_refresh(
1552        &mut self,
1553        project: Entity<Project>,
1554        throttle_entity: EntityId,
1555        cx: &mut Context<Self>,
1556        do_refresh: impl FnOnce(
1557            WeakEntity<Self>,
1558            &mut AsyncApp,
1559        )
1560            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1561        + 'static,
1562    ) {
1563        let project_state = self.get_or_init_project(&project, cx);
1564        let pending_prediction_id = project_state.next_pending_prediction_id;
1565        project_state.next_pending_prediction_id += 1;
1566        let last_request = project_state.last_prediction_refresh;
1567
1568        let task = cx.spawn(async move |this, cx| {
1569            if let Some((last_entity, last_timestamp)) = last_request
1570                && throttle_entity == last_entity
1571                && let Some(timeout) =
1572                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1573            {
1574                cx.background_executor().timer(timeout).await;
1575            }
1576
1577            // If this task was cancelled before the throttle timeout expired,
1578            // do not perform a request.
1579            let mut is_cancelled = true;
1580            this.update(cx, |this, cx| {
1581                let project_state = this.get_or_init_project(&project, cx);
1582                if !project_state
1583                    .cancelled_predictions
1584                    .remove(&pending_prediction_id)
1585                {
1586                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1587                    is_cancelled = false;
1588                }
1589            })
1590            .ok();
1591            if is_cancelled {
1592                return None;
1593            }
1594
1595            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1596            let new_prediction_id = new_prediction_result
1597                .as_ref()
1598                .map(|(prediction, _)| prediction.id.clone());
1599
1600            // When a prediction completes, remove it from the pending list, and cancel
1601            // any pending predictions that were enqueued before it.
1602            this.update(cx, |this, cx| {
1603                let project_state = this.get_or_init_project(&project, cx);
1604
1605                let is_cancelled = project_state
1606                    .cancelled_predictions
1607                    .remove(&pending_prediction_id);
1608
1609                let new_current_prediction = if !is_cancelled
1610                    && let Some((prediction_result, requested_by)) = new_prediction_result
1611                {
1612                    match prediction_result.prediction {
1613                        Ok(prediction) => {
1614                            let new_prediction = CurrentEditPrediction {
1615                                requested_by,
1616                                prediction,
1617                                was_shown: false,
1618                                shown_with: None,
1619                            };
1620
1621                            if let Some(current_prediction) =
1622                                project_state.current_prediction.as_ref()
1623                            {
1624                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1625                                {
1626                                    this.reject_current_prediction(
1627                                        EditPredictionRejectReason::Replaced,
1628                                        &project,
1629                                    );
1630
1631                                    Some(new_prediction)
1632                                } else {
1633                                    this.reject_prediction(
1634                                        new_prediction.prediction.id,
1635                                        EditPredictionRejectReason::CurrentPreferred,
1636                                        false,
1637                                    );
1638                                    None
1639                                }
1640                            } else {
1641                                Some(new_prediction)
1642                            }
1643                        }
1644                        Err(reject_reason) => {
1645                            this.reject_prediction(prediction_result.id, reject_reason, false);
1646                            None
1647                        }
1648                    }
1649                } else {
1650                    None
1651                };
1652
1653                let project_state = this.get_or_init_project(&project, cx);
1654
1655                if let Some(new_prediction) = new_current_prediction {
1656                    project_state.current_prediction = Some(new_prediction);
1657                }
1658
1659                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1660                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1661                    if pending_prediction.id == pending_prediction_id {
1662                        pending_predictions.remove(ix);
1663                        for pending_prediction in pending_predictions.drain(0..ix) {
1664                            project_state.cancel_pending_prediction(pending_prediction, cx)
1665                        }
1666                        break;
1667                    }
1668                }
1669                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1670                cx.notify();
1671            })
1672            .ok();
1673
1674            new_prediction_id
1675        });
1676
1677        if project_state.pending_predictions.len() <= 1 {
1678            project_state.pending_predictions.push(PendingPrediction {
1679                id: pending_prediction_id,
1680                task,
1681            });
1682        } else if project_state.pending_predictions.len() == 2 {
1683            let pending_prediction = project_state.pending_predictions.pop().unwrap();
1684            project_state.pending_predictions.push(PendingPrediction {
1685                id: pending_prediction_id,
1686                task,
1687            });
1688            project_state.cancel_pending_prediction(pending_prediction, cx);
1689        }
1690    }
1691
1692    pub fn request_prediction(
1693        &mut self,
1694        project: &Entity<Project>,
1695        active_buffer: &Entity<Buffer>,
1696        position: language::Anchor,
1697        trigger: PredictEditsRequestTrigger,
1698        cx: &mut Context<Self>,
1699    ) -> Task<Result<Option<EditPredictionResult>>> {
1700        self.request_prediction_internal(
1701            project.clone(),
1702            active_buffer.clone(),
1703            position,
1704            trigger,
1705            cx.has_flag::<Zeta2FeatureFlag>(),
1706            cx,
1707        )
1708    }
1709
    /// Builds the inputs for a prediction request at `position` and dispatches
    /// them to the configured edit prediction model.
    ///
    /// When the model yields no prediction and `allow_jump` is true, this looks
    /// for a diagnostic outside the range already covered by the request (only
    /// if the project has recorded events) and, if one is found, issues a second
    /// request at that location with `allow_jump` set to false.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Number of rows above and below the cursor covered by the diagnostic search range.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        // Ensure the project state exists, then re-borrow it immutably.
        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let stored_events = project_state.events(cx);
        let has_events = !stored_events.is_empty();
        let events: Vec<Arc<zeta_prompt::Event>> =
            stored_events.into_iter().map(|e| e.event).collect();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let current_offset = position.to_offset(&snapshot);

        let mut user_actions: Vec<UserActionRecord> =
            project_state.user_actions.iter().cloned().collect();

        // If the last recorded action was in this buffer at a different offset,
        // append a cursor-movement record to the copy sent with this request.
        if let Some(last_action) = user_actions.last() {
            if last_action.buffer_id == active_buffer.entity_id()
                && current_offset != last_action.offset
            {
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                user_actions.push(UserActionRecord {
                    action_type: UserActionType::CursorMovement,
                    buffer_id: active_buffer.entity_id(),
                    line_number: cursor_point.row,
                    offset: current_offset,
                    timestamp_epoch_ms,
                });
            }
        }
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let related_files = self.context_for_project(&project, cx);

        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer.clone(),
            snapshot: snapshot.clone(),
            position,
            events,
            related_files,
            recent_paths: project_state.recent_paths.clone(),
            trigger,
            diagnostic_search_range: diagnostic_search_range.clone(),
            debug_tx,
            user_actions,
        };

        // An example may only be captured when the buffer has a file that passes
        // `can_collect_file` and every event passes `can_collect_events`.
        let can_collect_example = snapshot
            .file()
            .is_some_and(|file| self.can_collect_file(&project, file, cx))
            && self.can_collect_events(&inputs.events, cx);

        if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
            let events_for_capture =
                self.edit_history_for_project_with_pause_split_last_event(&project, cx);
            if let Some(example_task) = capture_example::capture_example(
                project.clone(),
                active_buffer.clone(),
                position,
                events_for_capture,
                false,
                cx,
            ) {
                // Capture runs in a detached task; errors are only logged.
                cx.spawn(async move |_this, _cx| {
                    let example = example_task.await?;
                    telemetry::event!("Edit Prediction Example Captured", example = example);
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx);
            }
        }
        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => {
                // Optionally fire a detached Zeta2 request, tagged with the
                // `Testing` trigger, alongside the Zeta1 request.
                if should_send_testing_zeta2_request() {
                    let mut zeta2_inputs = inputs.clone();
                    zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing;
                    zeta2::request_prediction_with_zeta2(
                        self,
                        zeta2_inputs,
                        Default::default(),
                        cx,
                    )
                    .detach();
                }
                zeta1::request_prediction_with_zeta1(self, inputs, cx)
            }
            EditPredictionModel::Zeta2 { version } => {
                zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
            }
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            // No prediction at the cursor: optionally retry at a diagnostic
            // outside the range this request already covered.
            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    // The follow-up request passes `allow_jump = false`, so it
                    // cannot trigger another jump.
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1856
    /// Picks a buffer position with a diagnostic to jump to.
    ///
    /// First looks in the active buffer for the diagnostic group whose primary
    /// entry is closest (by row) to the cursor and does not overlap
    /// `active_buffer_diagnostic_search_range`. If there is none, opens the
    /// project path with diagnostic summaries, other than the active buffer's
    /// path, that shares the longest run of leading path components with the
    /// active buffer, and uses the start of the diagnostic in that buffer with
    /// the smallest `severity` value.
    ///
    /// Returns `Ok(None)` when neither search finds a location.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Nothing suitable in the active buffer: fall back to another file in the project.
        if jump_location.is_none() {
            // `None` when the active buffer has no associated file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            });

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // NOTE(review): picks the smallest `severity` value, presumably
                // the most severe diagnostic; confirm the ordering of the type.
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1935
    /// Sends a raw completion request to `custom_url`, or to the
    /// `/predict_edits/raw` LLM endpoint when no custom URL is given.
    ///
    /// With the `cli-support` feature and an `eval_cache` provided, the response
    /// is looked up under a key made of `eval_cache_kind` and a hash of the URL
    /// and the pretty-printed request. A cache hit is returned without usage
    /// information; on a miss the response is written to the cache after the
    /// request succeeds.
    async fn send_raw_llm_request(
        request: RawCompletionRequest,
        client: Arc<Client>,
        custom_url: Option<Arc<Url>>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
        let url = if let Some(custom_url) = custom_url {
            custom_url.as_ref().clone()
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        // `Some` only on a cache miss; carries what is needed to store the response later.
        #[cfg(feature = "cli-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        // The final `true` is `require_auth`: acquiring the token must succeed.
        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
            true,
        )
        .await?;

        #[cfg(feature = "cli-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1995
1996    fn handle_api_response<T>(
1997        this: &WeakEntity<Self>,
1998        response: Result<(T, Option<EditPredictionUsage>)>,
1999        cx: &mut gpui::AsyncApp,
2000    ) -> Result<T> {
2001        match response {
2002            Ok((data, usage)) => {
2003                if let Some(usage) = usage {
2004                    this.update(cx, |this, cx| {
2005                        this.user_store.update(cx, |user_store, cx| {
2006                            user_store.update_edit_prediction_usage(usage, cx);
2007                        });
2008                    })
2009                    .ok();
2010                }
2011                Ok(data)
2012            }
2013            Err(err) => {
2014                if err.is::<ZedUpdateRequiredError>() {
2015                    cx.update(|cx| {
2016                        this.update(cx, |this, _cx| {
2017                            this.update_required = true;
2018                        })
2019                        .ok();
2020
2021                        let error_message: SharedString = err.to_string().into();
2022                        show_app_notification(
2023                            NotificationId::unique::<ZedUpdateRequiredError>(),
2024                            cx,
2025                            move |cx| {
2026                                cx.new(|cx| {
2027                                    ErrorMessagePrompt::new(error_message.clone(), cx)
2028                                        .with_link_button("Update Zed", "https://zed.dev/releases")
2029                                })
2030                            },
2031                        );
2032                    });
2033                }
2034                Err(err)
2035            }
2036        }
2037    }
2038
    /// POSTs a JSON request finished by `build` and deserializes the response body as `Res`.
    ///
    /// An LLM token is acquired up front. Failing to acquire it is an error when
    /// `require_auth` is true; otherwise the request is sent without an
    /// `Authorization` header. If an unsuccessful response carries the
    /// expired-token header and a token was in use, the token is refreshed and
    /// the request is retried once.
    ///
    /// # Errors
    ///
    /// Returns `ZedUpdateRequiredError` when the response reports a minimum
    /// required version newer than `app_version`, and an error containing the
    /// status and body for any other unsuccessful response.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();

        let mut token = if require_auth {
            Some(llm_token.acquire(&client).await?)
        } else {
            llm_token.acquire(&client).await.ok()
        };
        // Guards the single retry after a token refresh.
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // The version check runs on every response, before the status is
            // inspected. A missing or unparsable header skips the check.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                // Usage headers are optional; a parse failure yields `None`.
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && token.is_some()
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Refresh the token and loop around to resend the request.
                did_retry = true;
                token = Some(llm_token.refresh(&client).await?);
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2114
2115    pub fn refresh_context(
2116        &mut self,
2117        project: &Entity<Project>,
2118        buffer: &Entity<language::Buffer>,
2119        cursor_position: language::Anchor,
2120        cx: &mut Context<Self>,
2121    ) {
2122        self.get_or_init_project(project, cx)
2123            .context
2124            .update(cx, |store, cx| {
2125                store.refresh(buffer.clone(), cursor_position, cx);
2126            });
2127    }
2128
2129    #[cfg(feature = "cli-support")]
2130    pub fn set_context_for_buffer(
2131        &mut self,
2132        project: &Entity<Project>,
2133        related_files: Vec<RelatedFile>,
2134        cx: &mut Context<Self>,
2135    ) {
2136        self.get_or_init_project(project, cx)
2137            .context
2138            .update(cx, |store, cx| {
2139                store.set_related_files(related_files, cx);
2140            });
2141    }
2142
2143    fn is_file_open_source(
2144        &self,
2145        project: &Entity<Project>,
2146        file: &Arc<dyn File>,
2147        cx: &App,
2148    ) -> bool {
2149        if !file.is_local() || file.is_private() {
2150            return false;
2151        }
2152        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2153            return false;
2154        };
2155        project_state
2156            .license_detection_watchers
2157            .get(&file.worktree_id(cx))
2158            .as_ref()
2159            .is_some_and(|watcher| watcher.is_project_open_source())
2160    }
2161
2162    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2163        self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2164    }
2165
2166    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2167        if !self.data_collection_choice.is_enabled(cx) {
2168            return false;
2169        }
2170        events.iter().all(|event| {
2171            matches!(
2172                event.as_ref(),
2173                zeta_prompt::Event::BufferChange {
2174                    in_open_source_repo: true,
2175                    ..
2176                }
2177            )
2178        })
2179    }
2180
2181    fn load_data_collection_choice() -> DataCollectionChoice {
2182        let choice = KEY_VALUE_STORE
2183            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2184            .log_err()
2185            .flatten();
2186
2187        match choice.as_deref() {
2188            Some("true") => DataCollectionChoice::Enabled,
2189            Some("false") => DataCollectionChoice::Disabled,
2190            Some(_) => {
2191                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2192                DataCollectionChoice::NotAnswered
2193            }
2194            None => DataCollectionChoice::NotAnswered,
2195        }
2196    }
2197
2198    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2199        self.data_collection_choice = self.data_collection_choice.toggle();
2200        let new_choice = self.data_collection_choice;
2201        let is_enabled = new_choice.is_enabled(cx);
2202        db::write_and_log(cx, move || {
2203            KEY_VALUE_STORE.write_kvp(
2204                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2205                is_enabled.to_string(),
2206            )
2207        });
2208    }
2209
    /// Iterates over the predictions stored in `shown_predictions`, in stored order.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2213
    /// Returns the number of predictions stored in `shown_predictions`.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2217
    /// Returns whether the prediction with the given id has been rated via `rate_prediction`.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2221
    /// Records a rating for `prediction` and reports it through telemetry.
    ///
    /// The prediction id is remembered as rated, an "Edit Prediction Rated"
    /// event is emitted with the rating, the prediction inputs, the edits as a
    /// unified diff, and the free-form `feedback`, and the telemetry events are
    /// flushed in a detached task.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
            feedback
        );
        // Flush without waiting for the result.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2242}
2243
2244pub(crate) fn filter_redundant_excerpts(
2245    mut related_files: Vec<RelatedFile>,
2246    cursor_path: &Path,
2247    cursor_row_range: Range<u32>,
2248) -> Vec<RelatedFile> {
2249    for file in &mut related_files {
2250        if file.path.as_ref() == cursor_path {
2251            file.excerpts.retain(|excerpt| {
2252                excerpt.row_range.start < cursor_row_range.start
2253                    || excerpt.row_range.end > cursor_row_range.end
2254            });
2255        }
2256    }
2257    related_files.retain(|file| !file.excerpts.is_empty());
2258    related_files
2259}
2260
/// Error raised when a response reports a minimum required Zed version that is
/// newer than the running application version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version taken from the response's minimum-required-version header.
    minimum_version: Version,
}
2268
/// Key of an eval cache entry: the kind of entry paired with a 64-bit hash of
/// the request it was produced for.
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2271
2272#[cfg(feature = "cli-support")]
/// Category of an eval cache entry; the first component of `EvalCacheKey`.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so that the key tuple
/// can be used directly in hash-based maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2279
2280#[cfg(feature = "cli-support")]
2281impl std::fmt::Display for EvalCacheEntryKind {
2282    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2283        match self {
2284            EvalCacheEntryKind::Search => write!(f, "search"),
2285            EvalCacheEntryKind::Context => write!(f, "context"),
2286            EvalCacheEntryKind::Prediction => write!(f, "prediction"),
2287        }
2288    }
2289}
2290
/// Store for cached responses, keyed by `EvalCacheKey`.
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached value for `key`, or `None` if there is no entry.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` under `key`; `input` is the serialized request that produced it.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2296
/// The user's choice about data collection for edit predictions.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No recognized choice has been stored.
    NotAnswered,
    /// Data collection was turned on.
    Enabled,
    /// Data collection was turned off.
    Disabled,
}
2303
2304impl DataCollectionChoice {
2305    pub fn is_enabled(self, cx: &App) -> bool {
2306        if cx.is_staff() {
2307            return true;
2308        }
2309        match self {
2310            Self::Enabled => true,
2311            Self::NotAnswered | Self::Disabled => false,
2312        }
2313    }
2314
2315    #[must_use]
2316    pub fn toggle(&self) -> DataCollectionChoice {
2317        match self {
2318            Self::Enabled => Self::Disabled,
2319            Self::Disabled => Self::Enabled,
2320            Self::NotAnswered => Self::Enabled,
2321        }
2322    }
2323}
2324
2325impl From<bool> for DataCollectionChoice {
2326    fn from(value: bool) -> Self {
2327        match value {
2328            true => DataCollectionChoice::Enabled,
2329            false => DataCollectionChoice::Disabled,
2330        }
2331    }
2332}
2333
/// Marker type carrying the dismissal state of the edit prediction upsell.
struct ZedPredictUpsell;
2335
2336impl Dismissable for ZedPredictUpsell {
2337    const KEY: &'static str = "dismissed-edit-predict-upsell";
2338
2339    fn dismissed() -> bool {
2340        // To make this backwards compatible with older versions of Zed, we
2341        // check if the user has seen the previous Edit Prediction Onboarding
2342        // before, by checking the data collection choice which was written to
2343        // the database once the user clicked on "Accept and Enable"
2344        if KEY_VALUE_STORE
2345            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2346            .log_err()
2347            .is_some_and(|s| s.is_some())
2348        {
2349            return true;
2350        }
2351
2352        KEY_VALUE_STORE
2353            .read_kvp(Self::KEY)
2354            .log_err()
2355            .is_some_and(|s| s.is_some())
2356    }
2357}
2358
/// Returns whether the upsell modal should be shown, i.e. it has not been dismissed.
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2362
/// Registers the edit prediction workspace actions on every newly created workspace.
///
/// The actions open the Zed Predict onboarding modal, reset onboarding by
/// setting the edit prediction provider to `None` in the settings file, and
/// run the Copilot sign-in, reinstall, and sign-out flows.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
        // Starts Copilot for the project through the global store; `None` when
        // no global store exists or the store does not return a Copilot entity.
        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
            EditPredictionStore::try_global(cx).and_then(|store| {
                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
            })
        }

        // Each Copilot action does nothing when no Copilot entity is available.
        workspace.register_action(|workspace, _: &SignIn, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &SignOut, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_out(copilot, window, cx);
            }
        });
    })
    .detach();
}