edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
   7    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
   8    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
   9    ZED_VERSION_HEADER_NAME,
  10};
  11use collections::{HashMap, HashSet};
  12use db::kvp::{Dismissable, KEY_VALUE_STORE};
  13use edit_prediction_context::EditPredictionExcerptOptions;
  14use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  16use futures::{
  17    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  18    channel::mpsc::{self, UnboundedReceiver},
  19    select_biased,
  20};
  21use gpui::BackgroundExecutor;
  22use gpui::{
  23    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  24    http_client::{self, AsyncBody, Method},
  25    prelude::*,
  26};
  27use language::language_settings::all_language_settings;
  28use language::{Anchor, Buffer, File, Point, ToPoint};
  29use language::{BufferSnapshot, OffsetRangeExt};
  30use language_model::{LlmApiToken, RefreshLlmTokenListener};
  31use project::{Project, ProjectPath, WorktreeId};
  32use release_channel::AppVersion;
  33use semver::Version;
  34use serde::de::DeserializeOwned;
  35use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
  36use std::collections::{VecDeque, hash_map};
  37use workspace::Workspace;
  38
  39use std::ops::Range;
  40use std::path::Path;
  41use std::rc::Rc;
  42use std::str::FromStr as _;
  43use std::sync::{Arc, LazyLock};
  44use std::time::{Duration, Instant};
  45use std::{env, mem};
  46use thiserror::Error;
  47use util::{RangeExt as _, ResultExt as _};
  48use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  49
  50mod cursor_excerpt;
  51mod license_detection;
  52pub mod mercury;
  53mod onboarding_modal;
  54pub mod open_ai_response;
  55mod prediction;
  56pub mod sweep_ai;
  57
  58#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
  59pub mod udiff;
  60
  61mod zed_edit_prediction_delegate;
  62pub mod zeta1;
  63pub mod zeta2;
  64
  65#[cfg(test)]
  66mod edit_prediction_tests;
  67
  68use crate::license_detection::LicenseDetectionWatcher;
  69use crate::mercury::Mercury;
  70use crate::onboarding_modal::ZedPredictModal;
  71pub use crate::prediction::EditPrediction;
  72pub use crate::prediction::EditPredictionId;
  73use crate::prediction::EditPredictionResult;
  74pub use crate::sweep_ai::SweepAi;
  75pub use language_model::ApiKeyState;
  76pub use telemetry_events::EditPredictionRating;
  77pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  78
// Declares the actions in the `edit_prediction` namespace via gpui's
// `actions!` macro. The doc comment on each entry describes the action.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  88
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Maximum row distance between the end of the previous edit and the end of
/// the next edit for the two to be coalesced into a single history event.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): presumably the key under which the data collection choice is
// persisted — confirm against the code that reads/writes it.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the debounce window used when batching rejected
// predictions before reporting them — confirm against the use site.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
  94
/// Marker type for the feature flag named `"sweep-ai"`.
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}
 100
/// Marker type for the feature flag named `"mercury"`.
pub struct MercuryFeatureFlag;

impl FeatureFlag for MercuryFeatureFlag {
    const NAME: &str = "mercury";
}
 106
/// Options a freshly constructed [`EditPredictionStore`] starts with
/// (see `EditPredictionStore::new`).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    prompt_format: PromptFormat::DEFAULT,
};
 115
 116static USE_OLLAMA: LazyLock<bool> =
 117    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 118
 119static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 120    match env::var("ZED_ZETA2_MODEL").as_deref() {
 121        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 122        Ok(model) => model,
 123        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 124        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 125    }
 126    .to_string()
 127});
 128static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 129    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 130        if *USE_OLLAMA {
 131            Some("http://localhost:11434/v1/chat/completions".into())
 132        } else {
 133            None
 134        }
 135    })
 136});
 137
/// Marker type for the feature flag named `"zeta2"`.
///
/// Among other things, this flag gates refreshing predictions when project
/// diagnostics are updated (see `handle_project_event`).
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Overrides the trait's staff setting so this flag reports `true` for staff.
    fn enabled_for_staff() -> bool {
        true
    }
}
 147
/// Global wrapper holding the application-wide [`EditPredictionStore`] entity.
/// Set by `EditPredictionStore::global` and read by `try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 152
/// Central store for edit prediction state: per-project history and
/// predictions, the configured model, and the provider clients.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Refreshed whenever the `RefreshLlmTokenListener` emits an event
    // (wired up in `new` via `_llm_token_subscription`).
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    use_context: bool,
    options: ZetaOptions,
    update_required: bool,
    #[cfg(feature = "cli-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    // Initialized to `EditPredictionModel::Zeta2` in `new`.
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    // Sender half of the channel drained by the background task running
    // `handle_rejected_predictions` (spawned in `new`).
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
 172
/// Which backend model produces edit predictions.
///
/// `Zeta1` is the `Default` variant, but `EditPredictionStore::new`
/// explicitly starts with `Zeta2`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
}
 181
/// Bundle of inputs describing a single prediction request.
// NOTE(review): the consumers of this struct are outside this section; the
// field notes below are inferred from the field types — confirm at use sites.
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    // Presumably the cursor position the prediction is requested for.
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Arc<[RelatedFile]>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
}
 194
/// Tunable options for the store; see [`DEFAULT_OPTIONS`] for the defaults.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Excerpt sizing options.
    pub context: EditPredictionExcerptOptions,
    /// Prompt format to use.
    pub prompt_format: predict_edits_v3::PromptFormat,
}
 200
/// Events sent over a project's debug channel (see
/// `EditPredictionStore::debug_info`).
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
 208
/// Payload of [`DebugEvent::ContextRetrievalStarted`], emitted when the
/// related-excerpt store reports `StartedRefresh`.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // Sent as an empty string by `handle_excerpt_store_event`.
    pub search_prompt: String,
}
 215
/// Payload of [`DebugEvent::ContextRetrievalFinished`], emitted when the
/// related-excerpt store reports `FinishedRefresh`.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Label/value pairs (cache hits, max and mean LSP time).
    pub metadata: Vec<(&'static str, SharedString)>,
}
 222
/// Payload of [`DebugEvent::EditPredictionStarted`].
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    // NOTE(review): presumably the prompt sent to the model, when one is
    // available — confirm at the emitting site.
    pub prompt: Option<String>,
}
 229
/// Payload of [`DebugEvent::EditPredictionFinished`].
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    // NOTE(review): presumably the raw model output, when one is available —
    // confirm at the emitting site.
    pub model_output: Option<String>,
}
 236
/// Alias for the debug info type of the `predict_edits_v3` API.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 238
/// Per-project edit prediction state, created lazily by
/// `EditPredictionStore::get_or_init_project`.
struct ProjectState {
    // Finalized history events.
    events: VecDeque<Arc<zeta_prompt::Event>>,
    // Most recent change, kept open so that follow-up nearby edits can be
    // coalesced into it before it is finalized into `events`.
    last_event: Option<LastEvent>,
    // Paths of recently activated entries, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    // Buffers being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // Set by `EditPredictionStore::debug_info`; `None` until then.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions passed to `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    // One watcher per worktree; entries are removed when the worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // Subscription to the project's events (see `handle_project_event`).
    _subscription: gpui::Subscription,
}
 254
 255impl ProjectState {
 256    pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
 257        self.events
 258            .iter()
 259            .cloned()
 260            .chain(
 261                self.last_event
 262                    .as_ref()
 263                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 264            )
 265            .collect()
 266    }
 267
 268    fn cancel_pending_prediction(
 269        &mut self,
 270        pending_prediction: PendingPrediction,
 271        cx: &mut Context<EditPredictionStore>,
 272    ) {
 273        self.cancelled_predictions.insert(pending_prediction.id);
 274
 275        cx.spawn(async move |this, cx| {
 276            let Some(prediction_id) = pending_prediction.task.await else {
 277                return;
 278            };
 279
 280            this.update(cx, |this, _cx| {
 281                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 282            })
 283            .ok();
 284        })
 285        .detach()
 286    }
 287
 288    fn active_buffer(
 289        &self,
 290        project: &Entity<Project>,
 291        cx: &App,
 292    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 293        let project = project.read(cx);
 294        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 295        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 296        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 297        Some((active_buffer, registered_buffer.last_position))
 298    }
 299}
 300
/// The prediction currently held for a project, along with what requested it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // NOTE(review): presumably set once the prediction has been displayed —
    // confirm where it is written.
    pub was_shown: bool,
}
 307
 308impl CurrentEditPrediction {
 309    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 310        let Some(new_edits) = self
 311            .prediction
 312            .interpolate(&self.prediction.buffer.read(cx))
 313        else {
 314            return false;
 315        };
 316
 317        if self.prediction.buffer != old_prediction.prediction.buffer {
 318            return true;
 319        }
 320
 321        let Some(old_edits) = old_prediction
 322            .prediction
 323            .interpolate(&old_prediction.prediction.buffer.read(cx))
 324        else {
 325            return true;
 326        };
 327
 328        let requested_by_buffer_id = self.requested_by.buffer_id();
 329
 330        // This reduces the occurrence of UI thrash from replacing edits
 331        //
 332        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 333        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 334            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 335            && old_edits.len() == 1
 336            && new_edits.len() == 1
 337        {
 338            let (old_range, old_text) = &old_edits[0];
 339            let (new_range, new_text) = &new_edits[0];
 340            new_range == old_range && new_text.starts_with(old_text.as_ref())
 341        } else {
 342            true
 343        }
 344    }
 345}
 346
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// Requested in response to a diagnostics update; no buffer id attached.
    DiagnosticsUpdate,
    /// Requested by the buffer with the given entity id.
    Buffer(EntityId),
}
 352
 353impl PredictionRequestedBy {
 354    pub fn buffer_id(&self) -> Option<EntityId> {
 355        match self {
 356            PredictionRequestedBy::DiagnosticsUpdate => None,
 357            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 358        }
 359    }
 360}
 361
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    // Identifier recorded in `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    // Resolves to the id of the produced prediction, if any.
    task: Task<Option<EditPredictionId>>,
}
 367
/// A prediction from the perspective of a buffer.
// NOTE(review): presumably `Local` means the prediction edits this buffer and
// `Jump` means it targets another location — confirm where these are built.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
 374
 375#[cfg(test)]
 376impl std::ops::Deref for BufferEditPrediction<'_> {
 377    type Target = EditPrediction;
 378
 379    fn deref(&self) -> &Self::Target {
 380        match self {
 381            BufferEditPrediction::Local { prediction } => prediction,
 382            BufferEditPrediction::Jump { prediction } => prediction,
 383        }
 384    }
 385}
 386
/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    // Snapshot as of the last reported change; compared against the current
    // snapshot's version to detect new edits.
    snapshot: BufferSnapshot,
    last_position: Option<Anchor>,
    // [buffer event subscription, buffer release observer]
    _subscriptions: [gpui::Subscription; 2],
}
 392
/// The most recent, not yet finalized change to a buffer.
struct LastEvent {
    // Buffer state before the change.
    old_snapshot: BufferSnapshot,
    // Buffer state after the change; advanced when later edits are coalesced.
    new_snapshot: BufferSnapshot,
    // End of the last edit in the change, used to decide coalescing.
    end_edit_anchor: Option<Anchor>,
}
 398
 399impl LastEvent {
 400    pub fn finalize(
 401        &self,
 402        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 403        cx: &App,
 404    ) -> Option<Arc<zeta_prompt::Event>> {
 405        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
 406        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
 407
 408        let file = self.new_snapshot.file();
 409        let old_file = self.old_snapshot.file();
 410
 411        let in_open_source_repo = [file, old_file].iter().all(|file| {
 412            file.is_some_and(|file| {
 413                license_detection_watchers
 414                    .get(&file.worktree_id(cx))
 415                    .is_some_and(|watcher| watcher.is_project_open_source())
 416            })
 417        });
 418
 419        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
 420
 421        if path == old_path && diff.is_empty() {
 422            None
 423        } else {
 424            Some(Arc::new(zeta_prompt::Event::BufferChange {
 425                old_path,
 426                path,
 427                diff,
 428                in_open_source_repo,
 429                // TODO: Actually detect if this edit was predicted or not
 430                predicted: false,
 431            }))
 432        }
 433    }
 434}
 435
 436fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
 437    if let Some(file) = snapshot.file() {
 438        file.full_path(cx).into()
 439    } else {
 440        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 441    }
 442}
 443
 444impl EditPredictionStore {
 445    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 446        cx.try_global::<EditPredictionStoreGlobal>()
 447            .map(|global| global.0.clone())
 448    }
 449
 450    pub fn global(
 451        client: &Arc<Client>,
 452        user_store: &Entity<UserStore>,
 453        cx: &mut App,
 454    ) -> Entity<Self> {
 455        cx.try_global::<EditPredictionStoreGlobal>()
 456            .map(|global| global.0.clone())
 457            .unwrap_or_else(|| {
 458                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 459                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 460                ep_store
 461            })
 462    }
 463
    /// Creates a store with [`DEFAULT_OPTIONS`] and the `Zeta2` model selected.
    ///
    /// Besides building the struct this:
    /// - spawns a detached background task running
    ///   `handle_rejected_predictions` on the receiving end of the rejection channel;
    /// - subscribes to the global `RefreshLlmTokenListener` to refresh the LLM token;
    /// - runs `configure_context_retrieval` immediately, again once feature
    ///   flags are ready, and whenever the `SettingsStore` global changes.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections are queued on `reject_tx` and handled off the main
        // thread by the task below.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            // Refresh the token whenever the listener fires; errors are logged.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            #[cfg(feature = "cli-support")]
            eval_cache: None,
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        };

        // Configure now, then re-run once flags are ready and on every
        // settings change.
        this.configure_context_retrieval(cx);
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
 535
    /// Selects the model used for edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
 539
 540    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 541        self.sweep_ai.api_token.read(cx).has_key()
 542    }
 543
 544    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 545        self.mercury.api_token.read(cx).has_key()
 546    }
 547
    /// Installs the eval cache. Only available with the `cli-support` feature.
    #[cfg(feature = "cli-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 552
    /// Returns the current options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 556
    /// Replaces the current options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 560
    /// Sets the `use_context` flag.
    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }
 564
 565    pub fn clear_history(&mut self) {
 566        for project_state in self.projects.values_mut() {
 567            project_state.events.clear();
 568        }
 569    }
 570
 571    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 572        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 573            project_state.events.clear();
 574        }
 575    }
 576
 577    pub fn edit_history_for_project(
 578        &self,
 579        project: &Entity<Project>,
 580        cx: &App,
 581    ) -> Vec<Arc<zeta_prompt::Event>> {
 582        self.projects
 583            .get(&project.entity_id())
 584            .map(|project_state| project_state.events(cx))
 585            .unwrap_or_default()
 586    }
 587
 588    pub fn context_for_project<'a>(
 589        &'a self,
 590        project: &Entity<Project>,
 591        cx: &'a App,
 592    ) -> Arc<[RelatedFile]> {
 593        self.projects
 594            .get(&project.entity_id())
 595            .map(|project| project.context.read(cx).related_files())
 596            .unwrap_or_else(|| vec![].into())
 597    }
 598
 599    pub fn context_for_project_with_buffers<'a>(
 600        &'a self,
 601        project: &Entity<Project>,
 602        cx: &'a App,
 603    ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
 604        self.projects
 605            .get(&project.entity_id())
 606            .map(|project| project.context.read(cx).related_files_with_buffers())
 607    }
 608
 609    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 610        if self.edit_prediction_model == EditPredictionModel::Zeta2 {
 611            self.user_store.read(cx).edit_prediction_usage()
 612        } else {
 613            None
 614        }
 615    }
 616
    /// Ensures per-project state exists for `project`, creating it if needed.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
 620
    /// Registers `buffer` with `project` so that its edits are tracked,
    /// creating the project state first if needed. Registering an already
    /// registered buffer keeps the existing registration.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
 630
    /// Returns the state for `project`, creating it on first access.
    ///
    /// A new state gets its own `RelatedExcerptStore` (whose events are
    /// forwarded to `handle_excerpt_store_event`) and a subscription to the
    /// project's events handled by `handle_project_event`.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Detached: this subscription is not stored on the state.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                // Stored on the state, unlike the excerpt-store subscription above.
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
 662
    /// Drops all state held for `project`; a no-op for unregistered projects.
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
 666
 667    fn handle_excerpt_store_event(
 668        &mut self,
 669        project_entity_id: EntityId,
 670        event: &RelatedExcerptStoreEvent,
 671    ) {
 672        if let Some(project_state) = self.projects.get(&project_entity_id) {
 673            if let Some(debug_tx) = project_state.debug_tx.clone() {
 674                match event {
 675                    RelatedExcerptStoreEvent::StartedRefresh => {
 676                        debug_tx
 677                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 678                                ContextRetrievalStartedDebugEvent {
 679                                    project_entity_id: project_entity_id,
 680                                    timestamp: Instant::now(),
 681                                    search_prompt: String::new(),
 682                                },
 683                            ))
 684                            .ok();
 685                    }
 686                    RelatedExcerptStoreEvent::FinishedRefresh {
 687                        cache_hit_count,
 688                        cache_miss_count,
 689                        mean_definition_latency,
 690                        max_definition_latency,
 691                    } => {
 692                        debug_tx
 693                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 694                                ContextRetrievalFinishedDebugEvent {
 695                                    project_entity_id: project_entity_id,
 696                                    timestamp: Instant::now(),
 697                                    metadata: vec![
 698                                        (
 699                                            "Cache Hits",
 700                                            format!(
 701                                                "{}/{}",
 702                                                cache_hit_count,
 703                                                cache_hit_count + cache_miss_count
 704                                            )
 705                                            .into(),
 706                                        ),
 707                                        (
 708                                            "Max LSP Time",
 709                                            format!("{} ms", max_definition_latency.as_millis())
 710                                                .into(),
 711                                        ),
 712                                        (
 713                                            "Mean LSP Time",
 714                                            format!("{} ms", mean_definition_latency.as_millis())
 715                                                .into(),
 716                                        ),
 717                                    ],
 718                                },
 719                            ))
 720                            .ok();
 721                    }
 722                }
 723            }
 724        }
 725    }
 726
 727    pub fn debug_info(
 728        &mut self,
 729        project: &Entity<Project>,
 730        cx: &mut Context<Self>,
 731    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 732        let project_state = self.get_or_init_project(project, cx);
 733        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 734        project_state.debug_tx = Some(debug_watch_tx);
 735        debug_watch_rx
 736    }
 737
 738    fn handle_project_event(
 739        &mut self,
 740        project: Entity<Project>,
 741        event: &project::Event,
 742        cx: &mut Context<Self>,
 743    ) {
 744        // TODO [zeta2] init with recent paths
 745        match event {
 746            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 747                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 748                    return;
 749                };
 750                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 751                if let Some(path) = path {
 752                    if let Some(ix) = project_state
 753                        .recent_paths
 754                        .iter()
 755                        .position(|probe| probe == &path)
 756                    {
 757                        project_state.recent_paths.remove(ix);
 758                    }
 759                    project_state.recent_paths.push_front(path);
 760                }
 761            }
 762            project::Event::DiagnosticsUpdated { .. } => {
 763                if cx.has_flag::<Zeta2FeatureFlag>() {
 764                    self.refresh_prediction_from_diagnostics(project, cx);
 765                }
 766            }
 767            _ => (),
 768        }
 769    }
 770
    /// Returns the registration for `buffer` in `project_state`, creating it
    /// if the buffer has not been registered yet.
    ///
    /// As a side effect, ensures a `LicenseDetectionWatcher` exists for the
    /// worktree of the buffer's file (when it has one and the worktree is
    /// still known to the project).
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Remove the watcher again once the worktree is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    last_position: None,
                    _subscriptions: [
                        // Report changes on every `Edited` event while the
                        // project is still alive (it is held weakly here).
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Unregister the buffer when it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 833
        /// Records an edit to `buffer` in the project's edit history.
        ///
        /// Compares the buffer's current snapshot with the snapshot stored for the registered
        /// buffer. If the version is unchanged nothing happens. Otherwise the stored snapshot is
        /// replaced and the change is either merged into `last_event` (when it directly follows
        /// that event in the same buffer and ends within `CHANGE_GROUPING_LINE_SPAN` rows of the
        /// previous edit) or becomes a new `last_event`, with the previous one finalized into
        /// `events`.
 834    fn report_changes_for_buffer(
 835        &mut self,
 836        buffer: &Entity<Buffer>,
 837        project: &Entity<Project>,
 838        cx: &mut Context<Self>,
 839    ) {
 840        let project_state = self.get_or_init_project(project, cx);
 841        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
 842
 843        let new_snapshot = buffer.read(cx).snapshot();
            // Nothing was edited since the snapshot we already hold.
 844        if new_snapshot.version == registered_buffer.snapshot.version {
 845            return;
 846        }
 847
            // Store the new snapshot for the buffer and keep the previous one as the "before" state.
 848        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            // Anchor at the end of the last edit made since `old_snapshot`, if any.
 849        let end_edit_anchor = new_snapshot
 850            .anchored_edits_since::<Point>(&old_snapshot.version)
 851            .last()
 852            .map(|(_, range)| range.end);
 853        let events = &mut project_state.events;
 854
 855        if let Some(LastEvent {
 856            new_snapshot: last_new_snapshot,
 857            end_edit_anchor: last_end_edit_anchor,
 858            ..
 859        }) = project_state.last_event.as_mut()
 860        {
                // The change continues the pending event only if it starts from exactly the
                // snapshot that event ended on (same buffer, same version).
 861            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
 862                == last_new_snapshot.remote_id()
 863                && old_snapshot.version == last_new_snapshot.version;
 864
                // Both edit ends must be known and lie within the grouping span of each other,
                // measured in rows of the new snapshot.
 865            let should_coalesce = is_next_snapshot_of_same_buffer
 866                && end_edit_anchor
 867                    .as_ref()
 868                    .zip(last_end_edit_anchor.as_ref())
 869                    .is_some_and(|(a, b)| {
 870                        let a = a.to_point(&new_snapshot);
 871                        let b = b.to_point(&new_snapshot);
 872                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
 873                    });
 874
 875            if should_coalesce {
                    // Extend the pending event in place; its `old_snapshot` is left untouched.
 876                *last_end_edit_anchor = end_edit_anchor;
 877                *last_new_snapshot = new_snapshot;
 878                return;
 879            }
 880        }
 881
            // Drop the oldest recorded event once the queue is about to reach `EVENT_COUNT_MAX`.
 882        if events.len() + 1 >= EVENT_COUNT_MAX {
 883            events.pop_front();
 884        }
 885
            // Finalize the previous pending event (if any) and append whatever it yields.
 886        if let Some(event) = project_state.last_event.take() {
 887            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
 888        }
 889
            // The current change becomes the new pending event.
 890        project_state.last_event = Some(LastEvent {
 891            old_snapshot,
 892            new_snapshot,
 893            end_edit_anchor,
 894        });
 895    }
 896
 897    fn prediction_at(
 898        &mut self,
 899        buffer: &Entity<Buffer>,
 900        position: Option<language::Anchor>,
 901        project: &Entity<Project>,
 902        cx: &App,
 903    ) -> Option<BufferEditPrediction<'_>> {
 904        let project_state = self.projects.get_mut(&project.entity_id())?;
 905        if let Some(position) = position
 906            && let Some(buffer) = project_state
 907                .registered_buffers
 908                .get_mut(&buffer.entity_id())
 909        {
 910            buffer.last_position = Some(position);
 911        }
 912
 913        let CurrentEditPrediction {
 914            requested_by,
 915            prediction,
 916            ..
 917        } = project_state.current_prediction.as_ref()?;
 918
 919        if prediction.targets_buffer(buffer.read(cx)) {
 920            Some(BufferEditPrediction::Local { prediction })
 921        } else {
 922            let show_jump = match requested_by {
 923                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
 924                    requested_by_buffer_id == &buffer.entity_id()
 925                }
 926                PredictionRequestedBy::DiagnosticsUpdate => true,
 927            };
 928
 929            if show_jump {
 930                Some(BufferEditPrediction::Jump { prediction })
 931            } else {
 932                None
 933            }
 934        }
 935    }
 936
 937    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 938        match self.edit_prediction_model {
 939            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
 940            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
 941        }
 942
 943        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 944            return;
 945        };
 946
 947        let Some(prediction) = project_state.current_prediction.take() else {
 948            return;
 949        };
 950        let request_id = prediction.prediction.id.to_string();
 951        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
 952            project_state.cancel_pending_prediction(pending_prediction, cx);
 953        }
 954
 955        let client = self.client.clone();
 956        let llm_token = self.llm_token.clone();
 957        let app_version = AppVersion::global(cx);
 958        cx.spawn(async move |this, cx| {
 959            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
 960                http_client::Url::parse(&predict_edits_url)?
 961            } else {
 962                client
 963                    .http_client()
 964                    .build_zed_llm_url("/predict_edits/accept", &[])?
 965            };
 966
 967            let response = cx
 968                .background_spawn(Self::send_api_request::<()>(
 969                    move |builder| {
 970                        let req = builder.uri(url.as_ref()).body(
 971                            serde_json::to_string(&AcceptEditPredictionBody {
 972                                request_id: request_id.clone(),
 973                            })?
 974                            .into(),
 975                        );
 976                        Ok(req?)
 977                    },
 978                    client,
 979                    llm_token,
 980                    app_version,
 981                ))
 982                .await;
 983
 984            Self::handle_api_response(&this, response, cx)?;
 985            anyhow::Ok(())
 986        })
 987        .detach_and_log_err(cx);
 988    }
 989
 990    async fn handle_rejected_predictions(
 991        rx: UnboundedReceiver<EditPredictionRejection>,
 992        client: Arc<Client>,
 993        llm_token: LlmApiToken,
 994        app_version: Version,
 995        background_executor: BackgroundExecutor,
 996    ) {
 997        let mut rx = std::pin::pin!(rx.peekable());
 998        let mut batched = Vec::new();
 999
1000        while let Some(rejection) = rx.next().await {
1001            batched.push(rejection);
1002
1003            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1004                select_biased! {
1005                    next = rx.as_mut().peek().fuse() => {
1006                        if next.is_some() {
1007                            continue;
1008                        }
1009                    }
1010                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1011                }
1012            }
1013
1014            let url = client
1015                .http_client()
1016                .build_zed_llm_url("/predict_edits/reject", &[])
1017                .unwrap();
1018
1019            let flush_count = batched
1020                .len()
1021                // in case items have accumulated after failure
1022                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1023            let start = batched.len() - flush_count;
1024
1025            let body = RejectEditPredictionsBodyRef {
1026                rejections: &batched[start..],
1027            };
1028
1029            let result = Self::send_api_request::<()>(
1030                |builder| {
1031                    let req = builder
1032                        .uri(url.as_ref())
1033                        .body(serde_json::to_string(&body)?.into());
1034                    anyhow::Ok(req?)
1035                },
1036                client.clone(),
1037                llm_token.clone(),
1038                app_version.clone(),
1039            )
1040            .await;
1041
1042            if result.log_err().is_some() {
1043                batched.drain(start..);
1044            }
1045        }
1046    }
1047
1048    fn reject_current_prediction(
1049        &mut self,
1050        reason: EditPredictionRejectReason,
1051        project: &Entity<Project>,
1052    ) {
1053        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1054            project_state.pending_predictions.clear();
1055            if let Some(prediction) = project_state.current_prediction.take() {
1056                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1057            }
1058        };
1059    }
1060
1061    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1062        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1063            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1064                if !current_prediction.was_shown {
1065                    current_prediction.was_shown = true;
1066                    self.shown_predictions
1067                        .push_front(current_prediction.prediction.clone());
1068                    if self.shown_predictions.len() > 50 {
1069                        let completion = self.shown_predictions.pop_back().unwrap();
1070                        self.rated_predictions.remove(&completion.id);
1071                    }
1072                }
1073            }
1074        }
1075    }
1076
1077    fn reject_prediction(
1078        &mut self,
1079        prediction_id: EditPredictionId,
1080        reason: EditPredictionRejectReason,
1081        was_shown: bool,
1082    ) {
1083        match self.edit_prediction_model {
1084            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
1085            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1086        }
1087
1088        self.reject_predictions_tx
1089            .unbounded_send(EditPredictionRejection {
1090                request_id: prediction_id.to_string(),
1091                reason,
1092                was_shown,
1093            })
1094            .log_err();
1095    }
1096
1097    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1098        self.projects
1099            .get(&project.entity_id())
1100            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1101    }
1102
1103    pub fn refresh_prediction_from_buffer(
1104        &mut self,
1105        project: Entity<Project>,
1106        buffer: Entity<Buffer>,
1107        position: language::Anchor,
1108        cx: &mut Context<Self>,
1109    ) {
1110        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1111            let Some(request_task) = this
1112                .update(cx, |this, cx| {
1113                    this.request_prediction(
1114                        &project,
1115                        &buffer,
1116                        position,
1117                        PredictEditsRequestTrigger::Other,
1118                        cx,
1119                    )
1120                })
1121                .log_err()
1122            else {
1123                return Task::ready(anyhow::Ok(None));
1124            };
1125
1126            cx.spawn(async move |_cx| {
1127                request_task.await.map(|prediction_result| {
1128                    prediction_result.map(|prediction_result| {
1129                        (
1130                            prediction_result,
1131                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1132                        )
1133                    })
1134                })
1135            })
1136        })
1137    }
1138
1139    pub fn refresh_prediction_from_diagnostics(
1140        &mut self,
1141        project: Entity<Project>,
1142        cx: &mut Context<Self>,
1143    ) {
1144        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1145            return;
1146        };
1147
1148        // Prefer predictions from buffer
1149        if project_state.current_prediction.is_some() {
1150            return;
1151        };
1152
1153        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1154            let Some((active_buffer, snapshot, cursor_point)) = this
1155                .read_with(cx, |this, cx| {
1156                    let project_state = this.projects.get(&project.entity_id())?;
1157                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
1158                    let snapshot = buffer.read(cx).snapshot();
1159
1160                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
1161                        return None;
1162                    }
1163
1164                    let cursor_point = position
1165                        .map(|pos| pos.to_point(&snapshot))
1166                        .unwrap_or_default();
1167
1168                    Some((buffer, snapshot, cursor_point))
1169                })
1170                .log_err()
1171                .flatten()
1172            else {
1173                return Task::ready(anyhow::Ok(None));
1174            };
1175
1176            cx.spawn(async move |cx| {
1177                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1178                    active_buffer,
1179                    &snapshot,
1180                    Default::default(),
1181                    cursor_point,
1182                    &project,
1183                    cx,
1184                )
1185                .await?
1186                else {
1187                    return anyhow::Ok(None);
1188                };
1189
1190                let Some(prediction_result) = this
1191                    .update(cx, |this, cx| {
1192                        this.request_prediction(
1193                            &project,
1194                            &jump_buffer,
1195                            jump_position,
1196                            PredictEditsRequestTrigger::Diagnostics,
1197                            cx,
1198                        )
1199                    })?
1200                    .await?
1201                else {
1202                    return anyhow::Ok(None);
1203                };
1204
1205                this.update(cx, |this, cx| {
1206                    Some((
1207                        if this
1208                            .get_or_init_project(&project, cx)
1209                            .current_prediction
1210                            .is_none()
1211                        {
1212                            prediction_result
1213                        } else {
1214                            EditPredictionResult {
1215                                id: prediction_result.id,
1216                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1217                            }
1218                        },
1219                        PredictionRequestedBy::DiagnosticsUpdate,
1220                    ))
1221                })
1222            })
1223        });
1224    }
1225
1226    fn predictions_enabled_at(
1227        snapshot: &BufferSnapshot,
1228        position: Option<language::Anchor>,
1229        cx: &App,
1230    ) -> bool {
1231        let file = snapshot.file();
1232        let all_settings = all_language_settings(file, cx);
1233        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1234            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1235        {
1236            return false;
1237        }
1238
1239        if let Some(last_position) = position {
1240            let settings = snapshot.settings_at(last_position, cx);
1241
1242            if !settings.edit_predictions_disabled_in.is_empty()
1243                && let Some(scope) = snapshot.language_scope_at(last_position)
1244                && let Some(scope_name) = scope.override_name()
1245                && settings
1246                    .edit_predictions_disabled_in
1247                    .iter()
1248                    .any(|s| s == scope_name)
1249            {
1250                return false;
1251            }
1252        }
1253
1254        true
1255    }
1256
        /// Minimum interval that `queue_prediction_refresh` enforces between two consecutive
        /// refreshes queued for the same throttle entity.
1257    #[cfg(not(test))]
1258    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
        /// Test builds use a zero throttle interval.
1259    #[cfg(test)]
1260    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1261
        /// Queues a prediction refresh for `project`.
        ///
        /// The spawned task first waits out the remainder of `THROTTLE_TIMEOUT` if the most
        /// recent refresh was made for the same `throttle_entity`. Unless the task was cancelled
        /// in the meantime, it then runs `do_refresh` and decides whether the result becomes the
        /// project's current prediction, is rejected in favour of the existing one, or is
        /// reported as rejected with the reason it carries.
        ///
        /// The task resolves to the id of the prediction produced by `do_refresh`, if any.
        /// At most two pending predictions are tracked per project.
1262    fn queue_prediction_refresh(
1263        &mut self,
1264        project: Entity<Project>,
1265        throttle_entity: EntityId,
1266        cx: &mut Context<Self>,
1267        do_refresh: impl FnOnce(
1268            WeakEntity<Self>,
1269            &mut AsyncApp,
1270        )
1271            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1272        + 'static,
1273    ) {
1274        let project_state = self.get_or_init_project(&project, cx);
            // Allocate a fresh id that identifies this refresh in the pending/cancelled sets.
1275        let pending_prediction_id = project_state.next_pending_prediction_id;
1276        project_state.next_pending_prediction_id += 1;
1277        let last_request = project_state.last_prediction_refresh;
1278
1279        let task = cx.spawn(async move |this, cx| {
                // Only throttle when the previous refresh was for the same entity and its
                // throttle window has not elapsed yet.
1280            if let Some((last_entity, last_timestamp)) = last_request
1281                && throttle_entity == last_entity
1282                && let Some(timeout) =
1283                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1284            {
1285                cx.background_executor().timer(timeout).await;
1286            }
1287
1288            // If this task was cancelled before the throttle timeout expired,
1289            // do not perform a request.
                // Starts as `true` so that a failed `this.update` is also treated as cancelled.
1290            let mut is_cancelled = true;
1291            this.update(cx, |this, cx| {
1292                let project_state = this.get_or_init_project(&project, cx);
1293                if !project_state
1294                    .cancelled_predictions
1295                    .remove(&pending_prediction_id)
1296                {
1297                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1298                    is_cancelled = false;
1299                }
1300            })
1301            .ok();
1302            if is_cancelled {
1303                return None;
1304            }
1305
                // Errors from the refresh are logged and treated like "no prediction".
1306            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1307            let new_prediction_id = new_prediction_result
1308                .as_ref()
1309                .map(|(prediction, _)| prediction.id.clone());
1310
1311            // When a prediction completes, remove it from the pending list, and cancel
1312            // any pending predictions that were enqueued before it.
1313            this.update(cx, |this, cx| {
1314                let project_state = this.get_or_init_project(&project, cx);
1315
                    // The refresh may have been cancelled while the request was in flight.
1316                let is_cancelled = project_state
1317                    .cancelled_predictions
1318                    .remove(&pending_prediction_id);
1319
1320                let new_current_prediction = if !is_cancelled
1321                    && let Some((prediction_result, requested_by)) = new_prediction_result
1322                {
1323                    match prediction_result.prediction {
1324                        Ok(prediction) => {
1325                            let new_prediction = CurrentEditPrediction {
1326                                requested_by,
1327                                prediction,
1328                                was_shown: false,
1329                            };
1330
1331                            if let Some(current_prediction) =
1332                                project_state.current_prediction.as_ref()
1333                            {
1334                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1335                                {
                                        // The new prediction wins: reject the one it replaces.
1336                                    this.reject_current_prediction(
1337                                        EditPredictionRejectReason::Replaced,
1338                                        &project,
1339                                    );
1340
1341                                    Some(new_prediction)
1342                                } else {
                                        // The existing prediction is kept; the new one was never shown.
1343                                    this.reject_prediction(
1344                                        new_prediction.prediction.id,
1345                                        EditPredictionRejectReason::CurrentPreferred,
1346                                        false,
1347                                    );
1348                                    None
1349                                }
1350                            } else {
1351                                Some(new_prediction)
1352                            }
1353                        }
1354                        Err(reject_reason) => {
1355                            this.reject_prediction(prediction_result.id, reject_reason, false);
1356                            None
1357                        }
1358                    }
1359                } else {
1360                    None
1361                };
1362
                    // Re-fetch the state: the earlier borrow ended at the `this.reject_*` calls above.
1363                let project_state = this.get_or_init_project(&project, cx);
1364
1365                if let Some(new_prediction) = new_current_prediction {
1366                    project_state.current_prediction = Some(new_prediction);
1367                }
1368
                    // Temporarily move the pending list out so `project_state` can be used while
                    // the list is being edited; it is written back below.
1369                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1370                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1371                    if pending_prediction.id == pending_prediction_id {
                            // Remove this refresh, then cancel everything that was queued before it.
1372                        pending_predictions.remove(ix);
1373                        for pending_prediction in pending_predictions.drain(0..ix) {
1374                            project_state.cancel_pending_prediction(pending_prediction, cx)
1375                        }
1376                        break;
1377                    }
1378                }
1379                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1380                cx.notify();
1381            })
1382            .ok();
1383
1384            new_prediction_id
1385        });
1386
            // With zero or one pending prediction the new one is simply appended. With two, the
            // most recently queued one is replaced by the new one and cancelled.
1387        if project_state.pending_predictions.len() <= 1 {
1388            project_state.pending_predictions.push(PendingPrediction {
1389                id: pending_prediction_id,
1390                task,
1391            });
1392        } else if project_state.pending_predictions.len() == 2 {
1393            let pending_prediction = project_state.pending_predictions.pop().unwrap();
1394            project_state.pending_predictions.push(PendingPrediction {
1395                id: pending_prediction_id,
1396                task,
1397            });
1398            project_state.cancel_pending_prediction(pending_prediction, cx);
1399        }
1400    }
1401
1402    pub fn request_prediction(
1403        &mut self,
1404        project: &Entity<Project>,
1405        active_buffer: &Entity<Buffer>,
1406        position: language::Anchor,
1407        trigger: PredictEditsRequestTrigger,
1408        cx: &mut Context<Self>,
1409    ) -> Task<Result<Option<EditPredictionResult>>> {
1410        self.request_prediction_internal(
1411            project.clone(),
1412            active_buffer.clone(),
1413            position,
1414            trigger,
1415            cx.has_flag::<Zeta2FeatureFlag>(),
1416            cx,
1417        )
1418    }
1419
1420    fn request_prediction_internal(
1421        &mut self,
1422        project: Entity<Project>,
1423        active_buffer: Entity<Buffer>,
1424        position: language::Anchor,
1425        trigger: PredictEditsRequestTrigger,
1426        allow_jump: bool,
1427        cx: &mut Context<Self>,
1428    ) -> Task<Result<Option<EditPredictionResult>>> {
1429        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1430
1431        self.get_or_init_project(&project, cx);
1432        let project_state = self.projects.get(&project.entity_id()).unwrap();
1433        let events = project_state.events(cx);
1434        let has_events = !events.is_empty();
1435        let debug_tx = project_state.debug_tx.clone();
1436
1437        let snapshot = active_buffer.read(cx).snapshot();
1438        let cursor_point = position.to_point(&snapshot);
1439        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1440        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1441        let diagnostic_search_range =
1442            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1443
1444        let related_files = if self.use_context {
1445            self.context_for_project(&project, cx)
1446        } else {
1447            Vec::new().into()
1448        };
1449
1450        let inputs = EditPredictionModelInput {
1451            project: project.clone(),
1452            buffer: active_buffer.clone(),
1453            snapshot: snapshot.clone(),
1454            position,
1455            events,
1456            related_files,
1457            recent_paths: project_state.recent_paths.clone(),
1458            trigger,
1459            diagnostic_search_range: diagnostic_search_range.clone(),
1460            debug_tx,
1461        };
1462
1463        let task = match self.edit_prediction_model {
1464            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
1465            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
1466            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1467            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1468        };
1469
1470        cx.spawn(async move |this, cx| {
1471            let prediction = task.await?;
1472
1473            if prediction.is_none() && allow_jump {
1474                let cursor_point = position.to_point(&snapshot);
1475                if has_events
1476                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1477                        active_buffer.clone(),
1478                        &snapshot,
1479                        diagnostic_search_range,
1480                        cursor_point,
1481                        &project,
1482                        cx,
1483                    )
1484                    .await?
1485                {
1486                    return this
1487                        .update(cx, |this, cx| {
1488                            this.request_prediction_internal(
1489                                project,
1490                                jump_buffer,
1491                                jump_position,
1492                                trigger,
1493                                false,
1494                                cx,
1495                            )
1496                        })?
1497                        .await;
1498                }
1499
1500                return anyhow::Ok(None);
1501            }
1502
1503            Ok(prediction)
1504        })
1505    }
1506
1507    async fn next_diagnostic_location(
1508        active_buffer: Entity<Buffer>,
1509        active_buffer_snapshot: &BufferSnapshot,
1510        active_buffer_diagnostic_search_range: Range<Point>,
1511        active_buffer_cursor_point: Point,
1512        project: &Entity<Project>,
1513        cx: &mut AsyncApp,
1514    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1515        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1516        let mut jump_location = active_buffer_snapshot
1517            .diagnostic_groups(None)
1518            .into_iter()
1519            .filter_map(|(_, group)| {
1520                let range = &group.entries[group.primary_ix]
1521                    .range
1522                    .to_point(&active_buffer_snapshot);
1523                if range.overlaps(&active_buffer_diagnostic_search_range) {
1524                    None
1525                } else {
1526                    Some(range.start)
1527                }
1528            })
1529            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1530            .map(|position| {
1531                (
1532                    active_buffer.clone(),
1533                    active_buffer_snapshot.anchor_before(position),
1534                )
1535            });
1536
1537        if jump_location.is_none() {
1538            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1539                let file = buffer.file()?;
1540
1541                Some(ProjectPath {
1542                    worktree_id: file.worktree_id(cx),
1543                    path: file.path().clone(),
1544                })
1545            })?;
1546
1547            let buffer_task = project.update(cx, |project, cx| {
1548                let (path, _, _) = project
1549                    .diagnostic_summaries(false, cx)
1550                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1551                    .max_by_key(|(path, _, _)| {
1552                        // find the buffer with errors that shares most parent directories
1553                        path.path
1554                            .components()
1555                            .zip(
1556                                active_buffer_path
1557                                    .as_ref()
1558                                    .map(|p| p.path.components())
1559                                    .unwrap_or_default(),
1560                            )
1561                            .take_while(|(a, b)| a == b)
1562                            .count()
1563                    })?;
1564
1565                Some(project.open_buffer(path, cx))
1566            })?;
1567
1568            if let Some(buffer_task) = buffer_task {
1569                let closest_buffer = buffer_task.await?;
1570
1571                jump_location = closest_buffer
1572                    .read_with(cx, |buffer, _cx| {
1573                        buffer
1574                            .buffer_diagnostics(None)
1575                            .into_iter()
1576                            .min_by_key(|entry| entry.diagnostic.severity)
1577                            .map(|entry| entry.range.start)
1578                    })?
1579                    .map(|position| (closest_buffer, position));
1580            }
1581        }
1582
1583        anyhow::Ok(jump_location)
1584    }
1585
1586    async fn send_raw_llm_request(
1587        request: open_ai::Request,
1588        client: Arc<Client>,
1589        llm_token: LlmApiToken,
1590        app_version: Version,
1591        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1592        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
1593    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1594        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1595            http_client::Url::parse(&predict_edits_url)?
1596        } else {
1597            client
1598                .http_client()
1599                .build_zed_llm_url("/predict_edits/raw", &[])?
1600        };
1601
1602        #[cfg(feature = "cli-support")]
1603        let cache_key = if let Some(cache) = eval_cache {
1604            use collections::FxHasher;
1605            use std::hash::{Hash, Hasher};
1606
1607            let mut hasher = FxHasher::default();
1608            url.hash(&mut hasher);
1609            let request_str = serde_json::to_string_pretty(&request)?;
1610            request_str.hash(&mut hasher);
1611            let hash = hasher.finish();
1612
1613            let key = (eval_cache_kind, hash);
1614            if let Some(response_str) = cache.read(key) {
1615                return Ok((serde_json::from_str(&response_str)?, None));
1616            }
1617
1618            Some((cache, request_str, key))
1619        } else {
1620            None
1621        };
1622
1623        let (response, usage) = Self::send_api_request(
1624            |builder| {
1625                let req = builder
1626                    .uri(url.as_ref())
1627                    .body(serde_json::to_string(&request)?.into());
1628                Ok(req?)
1629            },
1630            client,
1631            llm_token,
1632            app_version,
1633        )
1634        .await?;
1635
1636        #[cfg(feature = "cli-support")]
1637        if let Some((cache, request, key)) = cache_key {
1638            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1639        }
1640
1641        Ok((response, usage))
1642    }
1643
1644    fn handle_api_response<T>(
1645        this: &WeakEntity<Self>,
1646        response: Result<(T, Option<EditPredictionUsage>)>,
1647        cx: &mut gpui::AsyncApp,
1648    ) -> Result<T> {
1649        match response {
1650            Ok((data, usage)) => {
1651                if let Some(usage) = usage {
1652                    this.update(cx, |this, cx| {
1653                        this.user_store.update(cx, |user_store, cx| {
1654                            user_store.update_edit_prediction_usage(usage, cx);
1655                        });
1656                    })
1657                    .ok();
1658                }
1659                Ok(data)
1660            }
1661            Err(err) => {
1662                if err.is::<ZedUpdateRequiredError>() {
1663                    cx.update(|cx| {
1664                        this.update(cx, |this, _cx| {
1665                            this.update_required = true;
1666                        })
1667                        .ok();
1668
1669                        let error_message: SharedString = err.to_string().into();
1670                        show_app_notification(
1671                            NotificationId::unique::<ZedUpdateRequiredError>(),
1672                            cx,
1673                            move |cx| {
1674                                cx.new(|cx| {
1675                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1676                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1677                                })
1678                            },
1679                        );
1680                    })
1681                    .ok();
1682                }
1683                Err(err)
1684            }
1685        }
1686    }
1687
1688    async fn send_api_request<Res>(
1689        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1690        client: Arc<Client>,
1691        llm_token: LlmApiToken,
1692        app_version: Version,
1693    ) -> Result<(Res, Option<EditPredictionUsage>)>
1694    where
1695        Res: DeserializeOwned,
1696    {
1697        let http_client = client.http_client();
1698        let mut token = llm_token.acquire(&client).await?;
1699        let mut did_retry = false;
1700
1701        loop {
1702            let request_builder = http_client::Request::builder().method(Method::POST);
1703
1704            let request = build(
1705                request_builder
1706                    .header("Content-Type", "application/json")
1707                    .header("Authorization", format!("Bearer {}", token))
1708                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
1709            )?;
1710
1711            let mut response = http_client.send(request).await?;
1712
1713            if let Some(minimum_required_version) = response
1714                .headers()
1715                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1716                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1717            {
1718                anyhow::ensure!(
1719                    app_version >= minimum_required_version,
1720                    ZedUpdateRequiredError {
1721                        minimum_version: minimum_required_version
1722                    }
1723                );
1724            }
1725
1726            if response.status().is_success() {
1727                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1728
1729                let mut body = Vec::new();
1730                response.body_mut().read_to_end(&mut body).await?;
1731                return Ok((serde_json::from_slice(&body)?, usage));
1732            } else if !did_retry
1733                && response
1734                    .headers()
1735                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1736                    .is_some()
1737            {
1738                did_retry = true;
1739                token = llm_token.refresh(&client).await?;
1740            } else {
1741                let mut body = String::new();
1742                response.body_mut().read_to_string(&mut body).await?;
1743                anyhow::bail!(
1744                    "Request failed with status: {:?}\nBody: {}",
1745                    response.status(),
1746                    body
1747                );
1748            }
1749        }
1750    }
1751
1752    pub fn refresh_context(
1753        &mut self,
1754        project: &Entity<Project>,
1755        buffer: &Entity<language::Buffer>,
1756        cursor_position: language::Anchor,
1757        cx: &mut Context<Self>,
1758    ) {
1759        if self.use_context {
1760            self.get_or_init_project(project, cx)
1761                .context
1762                .update(cx, |store, cx| {
1763                    store.refresh(buffer.clone(), cursor_position, cx);
1764                });
1765        }
1766    }
1767
1768    #[cfg(feature = "cli-support")]
1769    pub fn set_context_for_buffer(
1770        &mut self,
1771        project: &Entity<Project>,
1772        related_files: Vec<RelatedFile>,
1773        cx: &mut Context<Self>,
1774    ) {
1775        self.get_or_init_project(project, cx)
1776            .context
1777            .update(cx, |store, _| {
1778                store.set_related_files(related_files);
1779            });
1780    }
1781
1782    fn is_file_open_source(
1783        &self,
1784        project: &Entity<Project>,
1785        file: &Arc<dyn File>,
1786        cx: &App,
1787    ) -> bool {
1788        if !file.is_local() || file.is_private() {
1789            return false;
1790        }
1791        let Some(project_state) = self.projects.get(&project.entity_id()) else {
1792            return false;
1793        };
1794        project_state
1795            .license_detection_watchers
1796            .get(&file.worktree_id(cx))
1797            .as_ref()
1798            .is_some_and(|watcher| watcher.is_project_open_source())
1799    }
1800
1801    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1802        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1803    }
1804
1805    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
1806        if !self.data_collection_choice.is_enabled() {
1807            return false;
1808        }
1809        events.iter().all(|event| {
1810            matches!(
1811                event.as_ref(),
1812                zeta_prompt::Event::BufferChange {
1813                    in_open_source_repo: true,
1814                    ..
1815                }
1816            )
1817        })
1818    }
1819
1820    fn load_data_collection_choice() -> DataCollectionChoice {
1821        let choice = KEY_VALUE_STORE
1822            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1823            .log_err()
1824            .flatten();
1825
1826        match choice.as_deref() {
1827            Some("true") => DataCollectionChoice::Enabled,
1828            Some("false") => DataCollectionChoice::Disabled,
1829            Some(_) => {
1830                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1831                DataCollectionChoice::NotAnswered
1832            }
1833            None => DataCollectionChoice::NotAnswered,
1834        }
1835    }
1836
1837    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1838        self.data_collection_choice = self.data_collection_choice.toggle();
1839        let new_choice = self.data_collection_choice;
1840        db::write_and_log(cx, move || {
1841            KEY_VALUE_STORE.write_kvp(
1842                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1843                new_choice.is_enabled().to_string(),
1844            )
1845        });
1846    }
1847
1848    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1849        self.shown_predictions.iter()
1850    }
1851
1852    pub fn shown_completions_len(&self) -> usize {
1853        self.shown_predictions.len()
1854    }
1855
1856    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
1857        self.rated_predictions.contains(id)
1858    }
1859
    /// Records a user rating for `prediction` and reports it through telemetry.
    ///
    /// The prediction's id is added to `rated_predictions` (so `is_prediction_rated`
    /// returns `true` for it afterwards), an "Edit Prediction Rated" event is emitted with
    /// the rating, the prediction's inputs, its edits rendered as a unified diff, and the
    /// free-form `feedback`, and observers of this entity are notified.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Request a flush right away; the returned task is detached rather than awaited.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1878
1879    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
1880        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
1881            && all_language_settings(None, cx).edit_predictions.use_context;
1882    }
1883}
1884
/// Error produced by `send_api_request` when the server's minimum-required-version
/// header names a version newer than the running app. `handle_api_response` reacts to
/// it by flagging the store and showing an "Update Zed" notification.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum version taken from the response header; interpolated into the message.
    minimum_version: Version,
}
1892
/// Key for an [`EvalCache`] entry: the kind of entry plus a 64-bit hash. In
/// `send_raw_llm_request` the hash covers the request URL and the pretty-printed request.
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
1895
1896#[cfg(feature = "cli-support")]
/// Category of an eval cache entry; the first component of the eval cache key.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so that the key tuple
/// `(EvalCacheEntryKind, u64)` can be used directly in `HashMap`/`HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
1903
/// Renders the kind as its lowercase name: `context`, `search` or `prediction`.
#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(label)
    }
}
1914
/// Cache consulted by `send_raw_llm_request` (with the `cli-support` feature) so that a
/// previously seen request can be answered without sending it again.
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the value stored under `key`, or `None` if there is none. The caller in
    /// this file parses the returned string as the JSON response.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` under `key`. The caller in this file passes the pretty-printed
    /// request JSON as `input` and the pretty-printed response JSON as `value`.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
1920
/// The user's choice about edit prediction data collection.
///
/// Derives `PartialEq`/`Eq` so choices can be compared directly (`==`, `assert_eq!`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    /// No choice has been recorded yet.
    NotAnswered,
    /// Data collection is turned on.
    Enabled,
    /// Data collection is turned off.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for [`DataCollectionChoice::Enabled`].
    pub fn is_enabled(self) -> bool {
        match self {
            Self::Enabled => true,
            Self::NotAnswered | Self::Disabled => false,
        }
    }

    /// Returns `true` once a choice has been made, whether enabled or disabled.
    pub fn is_answered(self) -> bool {
        match self {
            Self::Enabled | Self::Disabled => true,
            Self::NotAnswered => false,
        }
    }

    /// Returns the choice that results from toggling this one.
    ///
    /// `Enabled` becomes `Disabled`; both `Disabled` and `NotAnswered` become `Enabled`.
    /// The receiver is left untouched, hence `#[must_use]`.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled => Self::Enabled,
            Self::NotAnswered => Self::Enabled,
        }
    }
}

/// Maps `true` to `Enabled` and `false` to `Disabled`; `NotAnswered` is never produced.
impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1961
/// Marker type whose `Dismissable` implementation reports whether the edit prediction
/// upsell has already been dismissed.
struct ZedPredictUpsell;
1963
1964impl Dismissable for ZedPredictUpsell {
1965    const KEY: &'static str = "dismissed-edit-predict-upsell";
1966
1967    fn dismissed() -> bool {
1968        // To make this backwards compatible with older versions of Zed, we
1969        // check if the user has seen the previous Edit Prediction Onboarding
1970        // before, by checking the data collection choice which was written to
1971        // the database once the user clicked on "Accept and Enable"
1972        if KEY_VALUE_STORE
1973            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1974            .log_err()
1975            .is_some_and(|s| s.is_some())
1976        {
1977            return true;
1978        }
1979
1980        KEY_VALUE_STORE
1981            .read_kvp(Self::KEY)
1982            .log_err()
1983            .is_some_and(|s| s.is_some())
1984    }
1985}
1986
1987pub fn should_show_upsell_modal() -> bool {
1988    !ZedPredictUpsell::dismissed()
1989}
1990
1991pub fn init(cx: &mut App) {
1992    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
1993        workspace.register_action(
1994            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
1995                ZedPredictModal::toggle(
1996                    workspace,
1997                    workspace.user_store().clone(),
1998                    workspace.client().clone(),
1999                    window,
2000                    cx,
2001                )
2002            },
2003        );
2004
2005        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2006            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2007                settings
2008                    .project
2009                    .all_languages
2010                    .features
2011                    .get_or_insert_default()
2012                    .edit_prediction_provider = Some(EditPredictionProvider::None)
2013            });
2014        });
2015    })
2016    .detach();
2017}