edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
   7    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
   8    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
   9    ZED_VERSION_HEADER_NAME,
  10};
  11use collections::{HashMap, HashSet};
  12use db::kvp::{Dismissable, KEY_VALUE_STORE};
  13use edit_prediction_context::EditPredictionExcerptOptions;
  14use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  16use futures::{
  17    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  18    channel::mpsc::{self, UnboundedReceiver},
  19    select_biased,
  20};
  21use gpui::BackgroundExecutor;
  22use gpui::{
  23    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  24    http_client::{self, AsyncBody, Method},
  25    prelude::*,
  26};
  27use language::language_settings::all_language_settings;
  28use language::{Anchor, Buffer, File, Point, ToPoint};
  29use language::{BufferSnapshot, OffsetRangeExt};
  30use language_model::{LlmApiToken, RefreshLlmTokenListener};
  31use project::{Project, ProjectPath, WorktreeId};
  32use release_channel::AppVersion;
  33use semver::Version;
  34use serde::de::DeserializeOwned;
  35use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
  36use std::collections::{VecDeque, hash_map};
  37use workspace::Workspace;
  38
  39use std::ops::Range;
  40use std::path::Path;
  41use std::rc::Rc;
  42use std::str::FromStr as _;
  43use std::sync::{Arc, LazyLock};
  44use std::time::{Duration, Instant};
  45use std::{env, mem};
  46use thiserror::Error;
  47use util::{RangeExt as _, ResultExt as _};
  48use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  49
  50mod cursor_excerpt;
  51mod license_detection;
  52pub mod mercury;
  53mod onboarding_modal;
  54pub mod open_ai_response;
  55mod prediction;
  56pub mod sweep_ai;
  57
  58#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
  59pub mod udiff;
  60
  61mod zed_edit_prediction_delegate;
  62pub mod zeta1;
  63pub mod zeta2;
  64
  65#[cfg(test)]
  66mod edit_prediction_tests;
  67
  68use crate::license_detection::LicenseDetectionWatcher;
  69use crate::mercury::Mercury;
  70use crate::onboarding_modal::ZedPredictModal;
  71pub use crate::prediction::EditPrediction;
  72pub use crate::prediction::EditPredictionId;
  73use crate::prediction::EditPredictionResult;
  74pub use crate::sweep_ai::SweepAi;
  75pub use telemetry_events::EditPredictionRating;
  76pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  77
// Declares the `ResetOnboarding` and `ClearHistory` actions in the `edit_prediction`
// namespace. The `///` lines inside the macro are left untouched because gpui may
// surface them as the actions' documentation.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  87
/// Maximum number of events to track.
///
/// When the finalized-event queue plus the incoming event would reach this limit, the
/// oldest queued event is popped before the new one is appended.
const EVENT_COUNT_MAX: usize = 6;
/// Two consecutive edits to the same buffer are coalesced into one history event when
/// the rows of their end anchors differ by at most this many lines.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Storage key for the user's data-collection choice (presumably a key-value-store
/// key; the reader/writer is outside this view).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval for reporting rejected predictions (consumed outside this view,
/// presumably by `handle_rejected_predictions`).
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
  93
/// Feature flag named `sweep-ai`; presumably gates the Sweep AI prediction provider
/// (the check happens outside this view).
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}

/// Feature flag named `mercury`; presumably gates the Mercury prediction provider
/// (the check happens outside this view).
pub struct MercuryFeatureFlag;

impl FeatureFlag for MercuryFeatureFlag {
    const NAME: &str = "mercury";
}
 105
/// Options an [`EditPredictionStore`] starts with: cursor excerpts between 128 and 512
/// bytes with `target_before_cursor_over_total_bytes` of 0.5 (judging by the name, an
/// even split around the cursor), and the default prompt format.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    prompt_format: PromptFormat::DEFAULT,
};
 114
 115static USE_OLLAMA: LazyLock<bool> =
 116    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 117
 118static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 119    match env::var("ZED_ZETA2_MODEL").as_deref() {
 120        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 121        Ok(model) => model,
 122        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 123        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 124    }
 125    .to_string()
 126});
 127static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 128    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 129        if *USE_OLLAMA {
 130            Some("http://localhost:11434/v1/chat/completions".into())
 131        } else {
 132            None
 133        }
 134    })
 135});
 136
 137pub struct Zeta2FeatureFlag;
 138
 139impl FeatureFlag for Zeta2FeatureFlag {
 140    const NAME: &'static str = "zeta2";
 141
 142    fn enabled_for_staff() -> bool {
 143        true
 144    }
 145}
 146
/// App-global handle to the shared [`EditPredictionStore`] entity; installed by
/// [`EditPredictionStore::global`] and read by [`EditPredictionStore::try_global`].
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 151
/// Central edit-prediction state: per-project edit history and context, the selected
/// prediction backend, and the plumbing used to talk to the LLM service.
pub struct EditPredictionStore {
    /// Client used, among other things, to refresh `llm_token`.
    client: Arc<Client>,
    /// User store; `usage` reads edit-prediction usage from it.
    user_store: Entity<UserStore>,
    /// LLM API token, refreshed whenever `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// Set via `set_use_context`; starts as `false`.
    use_context: bool,
    /// Prompt/excerpt options; starts as `DEFAULT_OPTIONS`.
    options: ZetaOptions,
    /// Starts as `false`. NOTE(review): presumably set when the server demands a newer
    /// client version — confirm against the code that writes it.
    update_required: bool,
    /// Evaluation cache installed via `with_eval_cache` (cli-support builds only).
    #[cfg(feature = "cli-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    /// Backend used for predictions; `new` selects `Zeta2`.
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    /// Loaded at construction via `load_data_collection_choice`.
    data_collection_choice: DataCollectionChoice,
    /// Sender for the channel drained by the background rejection task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    /// NOTE(review): presumably predictions that were displayed to the user — confirm.
    shown_predictions: VecDeque<EditPrediction>,
    /// NOTE(review): presumably ids of predictions the user has rated — confirm.
    rated_predictions: HashSet<EditPredictionId>,
}
 171
/// Backend used to produce edit predictions.
///
/// `Zeta1` is the `Default`, but `EditPredictionStore::new` starts with `Zeta2`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
}
 180
/// Inputs gathered for a single prediction request.
///
/// NOTE(review): built and consumed outside this view; the field notes below are based
/// on names and types and should be confirmed at the construction site.
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    /// Buffer the prediction is requested for.
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` taken for this request.
    snapshot: BufferSnapshot,
    /// Position in `buffer` the request is anchored at.
    position: Anchor,
    /// Edit history, presumably as produced by `ProjectState::events`.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Arc<[RelatedFile]>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Sender for `DebugEvent`s, if a debug view is attached.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
}
 193
/// Tunable options for Zeta prompt construction; `EditPredictionStore::new` starts
/// with [`DEFAULT_OPTIONS`].
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Size limits for the excerpt taken around the cursor.
    pub context: EditPredictionExcerptOptions,
    /// Prompt format used for requests.
    pub prompt_format: predict_edits_v3::PromptFormat,
}
 199
/// Diagnostic events for a debug view. Context-retrieval events are sent through the
/// sender installed by [`EditPredictionStore::debug_info`]; the prediction events are
/// emitted outside this view.
#[derive(Debug)]
pub enum DebugEvent {
    /// Context retrieval for a project started.
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    /// Context retrieval for a project finished.
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    /// A prediction request started.
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    /// A prediction request finished.
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
 207
/// Payload of [`DebugEvent::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    /// Project whose context store started refreshing.
    pub project_entity_id: EntityId,
    /// When the event was created.
    pub timestamp: Instant,
    /// Search prompt for the retrieval; `handle_excerpt_store_event` sends it empty.
    pub search_prompt: String,
}

/// Payload of [`DebugEvent::ContextRetrievalFinished`].
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    /// Project whose context store finished refreshing.
    pub project_entity_id: EntityId,
    /// When the event was created.
    pub timestamp: Instant,
    /// Human-readable `(label, value)` pairs such as cache hits and LSP latencies.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Payload of [`DebugEvent::EditPredictionStarted`].
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Position in `buffer` the request was anchored at.
    pub position: Anchor,
    /// Presumably the prompt sent to the model, when one was built — confirm.
    pub prompt: Option<String>,
}

/// Payload of [`DebugEvent::EditPredictionFinished`].
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Position in `buffer` the request was anchored at.
    pub position: Anchor,
    /// Presumably the raw model output, when available — confirm.
    pub model_output: Option<String>,
}

/// Debug information returned by the v3 predict-edits API.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 237
/// Edit-prediction state tracked for a single project.
struct ProjectState {
    /// Finalized edit-history events, oldest first; the oldest is popped when full.
    events: VecDeque<Arc<zeta_prompt::Event>>,
    /// The most recent edit, still open for coalescing. It is finalized lazily by
    /// `ProjectState::events` and pushed into `events` when a non-coalescable edit arrives.
    last_event: Option<LastEvent>,
    /// Recently active project paths, most recent first, without duplicates.
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// Source of ids for pending predictions (presumably incremented per request).
    next_pending_prediction_id: usize,
    /// In-flight prediction requests; at most two (the `ArrayVec` capacity).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Debug-event sender installed by `EditPredictionStore::debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    /// NOTE(review): presumably the buffer and time of the last prediction refresh,
    /// used for throttling — confirm.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled.
    cancelled_predictions: HashSet<usize>,
    /// Store of related excerpts used as context for this project's predictions.
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree; they decide `in_open_source_repo` for history events.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Subscription routing project events to `EditPredictionStore::handle_project_event`.
    _subscription: gpui::Subscription,
}
 253
 254impl ProjectState {
 255    pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
 256        self.events
 257            .iter()
 258            .cloned()
 259            .chain(
 260                self.last_event
 261                    .as_ref()
 262                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 263            )
 264            .collect()
 265    }
 266
 267    fn cancel_pending_prediction(
 268        &mut self,
 269        pending_prediction: PendingPrediction,
 270        cx: &mut Context<EditPredictionStore>,
 271    ) {
 272        self.cancelled_predictions.insert(pending_prediction.id);
 273
 274        cx.spawn(async move |this, cx| {
 275            let Some(prediction_id) = pending_prediction.task.await else {
 276                return;
 277            };
 278
 279            this.update(cx, |this, _cx| {
 280                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 281            })
 282            .ok();
 283        })
 284        .detach()
 285    }
 286
 287    fn active_buffer(
 288        &self,
 289        project: &Entity<Project>,
 290        cx: &App,
 291    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 292        let project = project.read(cx);
 293        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 294        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 295        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 296        Some((active_buffer, registered_buffer.last_position))
 297    }
 298}
 299
/// The prediction currently held for a project, plus bookkeeping about its origin.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// NOTE(review): presumably set once the prediction has been displayed — confirm.
    pub was_shown: bool,
}
 306
 307impl CurrentEditPrediction {
 308    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 309        let Some(new_edits) = self
 310            .prediction
 311            .interpolate(&self.prediction.buffer.read(cx))
 312        else {
 313            return false;
 314        };
 315
 316        if self.prediction.buffer != old_prediction.prediction.buffer {
 317            return true;
 318        }
 319
 320        let Some(old_edits) = old_prediction
 321            .prediction
 322            .interpolate(&old_prediction.prediction.buffer.read(cx))
 323        else {
 324            return true;
 325        };
 326
 327        let requested_by_buffer_id = self.requested_by.buffer_id();
 328
 329        // This reduces the occurrence of UI thrash from replacing edits
 330        //
 331        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 332        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 333            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 334            && old_edits.len() == 1
 335            && new_edits.len() == 1
 336        {
 337            let (old_range, old_text) = &old_edits[0];
 338            let (new_range, new_text) = &new_edits[0];
 339            new_range == old_range && new_text.starts_with(old_text.as_ref())
 340        } else {
 341            true
 342        }
 343    }
 344}
 345
 346#[derive(Debug, Clone)]
 347enum PredictionRequestedBy {
 348    DiagnosticsUpdate,
 349    Buffer(EntityId),
 350}
 351
 352impl PredictionRequestedBy {
 353    pub fn buffer_id(&self) -> Option<EntityId> {
 354        match self {
 355            PredictionRequestedBy::DiagnosticsUpdate => None,
 356            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 357        }
 358    }
 359}
 360
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Per-project request id; recorded in `cancelled_predictions` on cancellation.
    id: usize,
    /// Resolves to the produced prediction's id; `None` presumably means no
    /// prediction was produced.
    task: Task<Option<EditPredictionId>>,
}
 366
/// A prediction from the perspective of a buffer.
///
/// NOTE(review): judging by the names, `Local` is a prediction editing this buffer and
/// `Jump` one whose edits require moving elsewhere — confirm at the construction site.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
 373
 374#[cfg(test)]
 375impl std::ops::Deref for BufferEditPrediction<'_> {
 376    type Target = EditPrediction;
 377
 378    fn deref(&self) -> &Self::Target {
 379        match self {
 380            BufferEditPrediction::Local { prediction } => prediction,
 381            BufferEditPrediction::Jump { prediction } => prediction,
 382        }
 383    }
 384}
 385
/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    /// Snapshot as of the last processed change; the next change is diffed against it.
    snapshot: BufferSnapshot,
    /// Last recorded position in this buffer (presumably the cursor), if any.
    last_position: Option<Anchor>,
    /// Buffer-edit and buffer-release subscriptions, kept alive with this entry.
    _subscriptions: [gpui::Subscription; 2],
}
 391
/// The most recent buffer change, kept open so nearby follow-up edits to the same
/// buffer can be coalesced into it before it becomes a history event.
struct LastEvent {
    /// Buffer snapshot from before the change.
    old_snapshot: BufferSnapshot,
    /// Buffer snapshot after the latest coalesced edit.
    new_snapshot: BufferSnapshot,
    /// End of the latest coalesced edit; the next edit is coalesced only if its end is
    /// within `CHANGE_GROUPING_LINE_SPAN` rows of this.
    end_edit_anchor: Option<Anchor>,
}
 397
 398impl LastEvent {
 399    pub fn finalize(
 400        &self,
 401        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 402        cx: &App,
 403    ) -> Option<Arc<zeta_prompt::Event>> {
 404        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
 405        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
 406
 407        let file = self.new_snapshot.file();
 408        let old_file = self.old_snapshot.file();
 409
 410        let in_open_source_repo = [file, old_file].iter().all(|file| {
 411            file.is_some_and(|file| {
 412                license_detection_watchers
 413                    .get(&file.worktree_id(cx))
 414                    .is_some_and(|watcher| watcher.is_project_open_source())
 415            })
 416        });
 417
 418        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
 419
 420        if path == old_path && diff.is_empty() {
 421            None
 422        } else {
 423            Some(Arc::new(zeta_prompt::Event::BufferChange {
 424                old_path,
 425                path,
 426                diff,
 427                in_open_source_repo,
 428                // TODO: Actually detect if this edit was predicted or not
 429                predicted: false,
 430            }))
 431        }
 432    }
 433}
 434
 435fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
 436    if let Some(file) = snapshot.file() {
 437        file.full_path(cx).into()
 438    } else {
 439        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 440    }
 441}
 442
 443impl EditPredictionStore {
 444    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 445        cx.try_global::<EditPredictionStoreGlobal>()
 446            .map(|global| global.0.clone())
 447    }
 448
 449    pub fn global(
 450        client: &Arc<Client>,
 451        user_store: &Entity<UserStore>,
 452        cx: &mut App,
 453    ) -> Entity<Self> {
 454        cx.try_global::<EditPredictionStoreGlobal>()
 455            .map(|global| global.0.clone())
 456            .unwrap_or_else(|| {
 457                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 458                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 459                ep_store
 460            })
 461    }
 462
    /// Creates the store.
    ///
    /// Besides initializing fields, this:
    /// - spawns a detached background task running `handle_rejected_predictions`, which
    ///   receives the rejections sent through `reject_predictions_tx`;
    /// - subscribes to `RefreshLlmTokenListener` and refreshes `llm_token` on each event;
    /// - calls `configure_context_retrieval` immediately, again once feature flags are
    ///   ready, and on every `SettingsStore` change.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejected predictions travel through this channel to a background task
        // (presumably batched and reported upstream; see `handle_rejected_predictions`).
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            // Refresh the shared LLM token whenever the global listener emits an event.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            #[cfg(feature = "cli-support")]
            eval_cache: None,
            // Note: differs from `EditPredictionModel::default()`, which is `Zeta1`.
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        };

        this.configure_context_retrieval(cx);
        // Re-run once feature flags have loaded (presumably because flags can change
        // the configuration), holding only a weak handle to the store.
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        // ...and again whenever settings change.
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
 534
 535    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 536        self.edit_prediction_model = model;
 537    }
 538
 539    pub fn has_sweep_api_token(&self) -> bool {
 540        self.sweep_ai
 541            .api_token
 542            .clone()
 543            .now_or_never()
 544            .flatten()
 545            .is_some()
 546    }
 547
 548    pub fn has_mercury_api_token(&self) -> bool {
 549        self.mercury
 550            .api_token
 551            .clone()
 552            .now_or_never()
 553            .flatten()
 554            .is_some()
 555    }
 556
 557    #[cfg(feature = "cli-support")]
 558    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
 559        self.eval_cache = Some(cache);
 560    }
 561
 562    pub fn options(&self) -> &ZetaOptions {
 563        &self.options
 564    }
 565
 566    pub fn set_options(&mut self, options: ZetaOptions) {
 567        self.options = options;
 568    }
 569
 570    pub fn set_use_context(&mut self, use_context: bool) {
 571        self.use_context = use_context;
 572    }
 573
 574    pub fn clear_history(&mut self) {
 575        for project_state in self.projects.values_mut() {
 576            project_state.events.clear();
 577        }
 578    }
 579
 580    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 581        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 582            project_state.events.clear();
 583        }
 584    }
 585
 586    pub fn edit_history_for_project(
 587        &self,
 588        project: &Entity<Project>,
 589        cx: &App,
 590    ) -> Vec<Arc<zeta_prompt::Event>> {
 591        self.projects
 592            .get(&project.entity_id())
 593            .map(|project_state| project_state.events(cx))
 594            .unwrap_or_default()
 595    }
 596
 597    pub fn context_for_project<'a>(
 598        &'a self,
 599        project: &Entity<Project>,
 600        cx: &'a App,
 601    ) -> Arc<[RelatedFile]> {
 602        self.projects
 603            .get(&project.entity_id())
 604            .map(|project| project.context.read(cx).related_files())
 605            .unwrap_or_else(|| vec![].into())
 606    }
 607
 608    pub fn context_for_project_with_buffers<'a>(
 609        &'a self,
 610        project: &Entity<Project>,
 611        cx: &'a App,
 612    ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
 613        self.projects
 614            .get(&project.entity_id())
 615            .map(|project| project.context.read(cx).related_files_with_buffers())
 616    }
 617
 618    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 619        if self.edit_prediction_model == EditPredictionModel::Zeta2 {
 620            self.user_store.read(cx).edit_prediction_usage()
 621        } else {
 622            None
 623        }
 624    }
 625
 626    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 627        self.get_or_init_project(project, cx);
 628    }
 629
 630    pub fn register_buffer(
 631        &mut self,
 632        buffer: &Entity<Buffer>,
 633        project: &Entity<Project>,
 634        cx: &mut Context<Self>,
 635    ) {
 636        let project_state = self.get_or_init_project(project, cx);
 637        Self::register_buffer_impl(project_state, buffer, project, cx);
 638    }
 639
 640    fn get_or_init_project(
 641        &mut self,
 642        project: &Entity<Project>,
 643        cx: &mut Context<Self>,
 644    ) -> &mut ProjectState {
 645        let entity_id = project.entity_id();
 646        self.projects
 647            .entry(entity_id)
 648            .or_insert_with(|| ProjectState {
 649                context: {
 650                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 651                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 652                        this.handle_excerpt_store_event(entity_id, event);
 653                    })
 654                    .detach();
 655                    related_excerpt_store
 656                },
 657                events: VecDeque::new(),
 658                last_event: None,
 659                recent_paths: VecDeque::new(),
 660                debug_tx: None,
 661                registered_buffers: HashMap::default(),
 662                current_prediction: None,
 663                cancelled_predictions: HashSet::default(),
 664                pending_predictions: ArrayVec::new(),
 665                next_pending_prediction_id: 0,
 666                last_prediction_refresh: None,
 667                license_detection_watchers: HashMap::default(),
 668                _subscription: cx.subscribe(&project, Self::handle_project_event),
 669            })
 670    }
 671
 672    pub fn remove_project(&mut self, project: &Entity<Project>) {
 673        self.projects.remove(&project.entity_id());
 674    }
 675
 676    fn handle_excerpt_store_event(
 677        &mut self,
 678        project_entity_id: EntityId,
 679        event: &RelatedExcerptStoreEvent,
 680    ) {
 681        if let Some(project_state) = self.projects.get(&project_entity_id) {
 682            if let Some(debug_tx) = project_state.debug_tx.clone() {
 683                match event {
 684                    RelatedExcerptStoreEvent::StartedRefresh => {
 685                        debug_tx
 686                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 687                                ContextRetrievalStartedDebugEvent {
 688                                    project_entity_id: project_entity_id,
 689                                    timestamp: Instant::now(),
 690                                    search_prompt: String::new(),
 691                                },
 692                            ))
 693                            .ok();
 694                    }
 695                    RelatedExcerptStoreEvent::FinishedRefresh {
 696                        cache_hit_count,
 697                        cache_miss_count,
 698                        mean_definition_latency,
 699                        max_definition_latency,
 700                    } => {
 701                        debug_tx
 702                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 703                                ContextRetrievalFinishedDebugEvent {
 704                                    project_entity_id: project_entity_id,
 705                                    timestamp: Instant::now(),
 706                                    metadata: vec![
 707                                        (
 708                                            "Cache Hits",
 709                                            format!(
 710                                                "{}/{}",
 711                                                cache_hit_count,
 712                                                cache_hit_count + cache_miss_count
 713                                            )
 714                                            .into(),
 715                                        ),
 716                                        (
 717                                            "Max LSP Time",
 718                                            format!("{} ms", max_definition_latency.as_millis())
 719                                                .into(),
 720                                        ),
 721                                        (
 722                                            "Mean LSP Time",
 723                                            format!("{} ms", mean_definition_latency.as_millis())
 724                                                .into(),
 725                                        ),
 726                                    ],
 727                                },
 728                            ))
 729                            .ok();
 730                    }
 731                }
 732            }
 733        }
 734    }
 735
 736    pub fn debug_info(
 737        &mut self,
 738        project: &Entity<Project>,
 739        cx: &mut Context<Self>,
 740    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 741        let project_state = self.get_or_init_project(project, cx);
 742        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 743        project_state.debug_tx = Some(debug_watch_tx);
 744        debug_watch_rx
 745    }
 746
 747    fn handle_project_event(
 748        &mut self,
 749        project: Entity<Project>,
 750        event: &project::Event,
 751        cx: &mut Context<Self>,
 752    ) {
 753        // TODO [zeta2] init with recent paths
 754        match event {
 755            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 756                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 757                    return;
 758                };
 759                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 760                if let Some(path) = path {
 761                    if let Some(ix) = project_state
 762                        .recent_paths
 763                        .iter()
 764                        .position(|probe| probe == &path)
 765                    {
 766                        project_state.recent_paths.remove(ix);
 767                    }
 768                    project_state.recent_paths.push_front(path);
 769                }
 770            }
 771            project::Event::DiagnosticsUpdated { .. } => {
 772                if cx.has_flag::<Zeta2FeatureFlag>() {
 773                    self.refresh_prediction_from_diagnostics(project, cx);
 774                }
 775            }
 776            _ => (),
 777        }
 778    }
 779
 780    fn register_buffer_impl<'a>(
 781        project_state: &'a mut ProjectState,
 782        buffer: &Entity<Buffer>,
 783        project: &Entity<Project>,
 784        cx: &mut Context<Self>,
 785    ) -> &'a mut RegisteredBuffer {
 786        let buffer_id = buffer.entity_id();
 787
 788        if let Some(file) = buffer.read(cx).file() {
 789            let worktree_id = file.worktree_id(cx);
 790            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
 791                project_state
 792                    .license_detection_watchers
 793                    .entry(worktree_id)
 794                    .or_insert_with(|| {
 795                        let project_entity_id = project.entity_id();
 796                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
 797                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
 798                            else {
 799                                return;
 800                            };
 801                            project_state
 802                                .license_detection_watchers
 803                                .remove(&worktree_id);
 804                        })
 805                        .detach();
 806                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
 807                    });
 808            }
 809        }
 810
 811        match project_state.registered_buffers.entry(buffer_id) {
 812            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 813            hash_map::Entry::Vacant(entry) => {
 814                let snapshot = buffer.read(cx).snapshot();
 815                let project_entity_id = project.entity_id();
 816                entry.insert(RegisteredBuffer {
 817                    snapshot,
 818                    last_position: None,
 819                    _subscriptions: [
 820                        cx.subscribe(buffer, {
 821                            let project = project.downgrade();
 822                            move |this, buffer, event, cx| {
 823                                if let language::BufferEvent::Edited = event
 824                                    && let Some(project) = project.upgrade()
 825                                {
 826                                    this.report_changes_for_buffer(&buffer, &project, cx);
 827                                }
 828                            }
 829                        }),
 830                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 831                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
 832                            else {
 833                                return;
 834                            };
 835                            project_state.registered_buffers.remove(&buffer_id);
 836                        }),
 837                    ],
 838                })
 839            }
 840        }
 841    }
 842
    /// Records an edit to `buffer` in `project`'s edit-event history.
    ///
    /// The buffer's tracking state is looked up (and created if needed) via
    /// `register_buffer_impl`. If the buffer's version is unchanged since the last recorded
    /// snapshot, nothing happens. Otherwise the edit is either coalesced into the pending
    /// `last_event`, or the pending event is finalized into `events` and a new pending event
    /// is started.
    ///
    /// Coalescing happens when the edit directly continues the pending event (same buffer,
    /// and the old snapshot is that event's new snapshot). The last edit's end must also be
    /// within `CHANGE_GROUPING_LINE_SPAN` rows of the pending event's end.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Nothing changed since the snapshot recorded for this buffer.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // End of the last edit between the two snapshots; used to decide whether to group.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // The pending event ends exactly where this edit starts.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // Group only when both edits end within CHANGE_GROUPING_LINE_SPAN rows.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // Extend the pending event in place instead of starting a new one.
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Bound the history, leaving room for the pending event finalized below.
        // NOTE(review): this pops even when there is no pending `last_event` to finalize.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        // Move the previous pending event (whatever `finalize` yields) into the history.
        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        // This edit becomes the new pending event.
        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
 905
 906    fn prediction_at(
 907        &mut self,
 908        buffer: &Entity<Buffer>,
 909        position: Option<language::Anchor>,
 910        project: &Entity<Project>,
 911        cx: &App,
 912    ) -> Option<BufferEditPrediction<'_>> {
 913        let project_state = self.projects.get_mut(&project.entity_id())?;
 914        if let Some(position) = position
 915            && let Some(buffer) = project_state
 916                .registered_buffers
 917                .get_mut(&buffer.entity_id())
 918        {
 919            buffer.last_position = Some(position);
 920        }
 921
 922        let CurrentEditPrediction {
 923            requested_by,
 924            prediction,
 925            ..
 926        } = project_state.current_prediction.as_ref()?;
 927
 928        if prediction.targets_buffer(buffer.read(cx)) {
 929            Some(BufferEditPrediction::Local { prediction })
 930        } else {
 931            let show_jump = match requested_by {
 932                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
 933                    requested_by_buffer_id == &buffer.entity_id()
 934                }
 935                PredictionRequestedBy::DiagnosticsUpdate => true,
 936            };
 937
 938            if show_jump {
 939                Some(BufferEditPrediction::Jump { prediction })
 940            } else {
 941                None
 942            }
 943        }
 944    }
 945
 946    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 947        match self.edit_prediction_model {
 948            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
 949            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
 950        }
 951
 952        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 953            return;
 954        };
 955
 956        let Some(prediction) = project_state.current_prediction.take() else {
 957            return;
 958        };
 959        let request_id = prediction.prediction.id.to_string();
 960        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
 961            project_state.cancel_pending_prediction(pending_prediction, cx);
 962        }
 963
 964        let client = self.client.clone();
 965        let llm_token = self.llm_token.clone();
 966        let app_version = AppVersion::global(cx);
 967        cx.spawn(async move |this, cx| {
 968            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
 969                http_client::Url::parse(&predict_edits_url)?
 970            } else {
 971                client
 972                    .http_client()
 973                    .build_zed_llm_url("/predict_edits/accept", &[])?
 974            };
 975
 976            let response = cx
 977                .background_spawn(Self::send_api_request::<()>(
 978                    move |builder| {
 979                        let req = builder.uri(url.as_ref()).body(
 980                            serde_json::to_string(&AcceptEditPredictionBody {
 981                                request_id: request_id.clone(),
 982                            })?
 983                            .into(),
 984                        );
 985                        Ok(req?)
 986                    },
 987                    client,
 988                    llm_token,
 989                    app_version,
 990                ))
 991                .await;
 992
 993            Self::handle_api_response(&this, response, cx)?;
 994            anyhow::Ok(())
 995        })
 996        .detach_and_log_err(cx);
 997    }
 998
 999    async fn handle_rejected_predictions(
1000        rx: UnboundedReceiver<EditPredictionRejection>,
1001        client: Arc<Client>,
1002        llm_token: LlmApiToken,
1003        app_version: Version,
1004        background_executor: BackgroundExecutor,
1005    ) {
1006        let mut rx = std::pin::pin!(rx.peekable());
1007        let mut batched = Vec::new();
1008
1009        while let Some(rejection) = rx.next().await {
1010            batched.push(rejection);
1011
1012            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1013                select_biased! {
1014                    next = rx.as_mut().peek().fuse() => {
1015                        if next.is_some() {
1016                            continue;
1017                        }
1018                    }
1019                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1020                }
1021            }
1022
1023            let url = client
1024                .http_client()
1025                .build_zed_llm_url("/predict_edits/reject", &[])
1026                .unwrap();
1027
1028            let flush_count = batched
1029                .len()
1030                // in case items have accumulated after failure
1031                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1032            let start = batched.len() - flush_count;
1033
1034            let body = RejectEditPredictionsBodyRef {
1035                rejections: &batched[start..],
1036            };
1037
1038            let result = Self::send_api_request::<()>(
1039                |builder| {
1040                    let req = builder
1041                        .uri(url.as_ref())
1042                        .body(serde_json::to_string(&body)?.into());
1043                    anyhow::Ok(req?)
1044                },
1045                client.clone(),
1046                llm_token.clone(),
1047                app_version.clone(),
1048            )
1049            .await;
1050
1051            if result.log_err().is_some() {
1052                batched.drain(start..);
1053            }
1054        }
1055    }
1056
1057    fn reject_current_prediction(
1058        &mut self,
1059        reason: EditPredictionRejectReason,
1060        project: &Entity<Project>,
1061    ) {
1062        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1063            project_state.pending_predictions.clear();
1064            if let Some(prediction) = project_state.current_prediction.take() {
1065                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1066            }
1067        };
1068    }
1069
1070    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1071        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1072            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1073                if !current_prediction.was_shown {
1074                    current_prediction.was_shown = true;
1075                    self.shown_predictions
1076                        .push_front(current_prediction.prediction.clone());
1077                    if self.shown_predictions.len() > 50 {
1078                        let completion = self.shown_predictions.pop_back().unwrap();
1079                        self.rated_predictions.remove(&completion.id);
1080                    }
1081                }
1082            }
1083        }
1084    }
1085
1086    fn reject_prediction(
1087        &mut self,
1088        prediction_id: EditPredictionId,
1089        reason: EditPredictionRejectReason,
1090        was_shown: bool,
1091    ) {
1092        match self.edit_prediction_model {
1093            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
1094            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1095        }
1096
1097        self.reject_predictions_tx
1098            .unbounded_send(EditPredictionRejection {
1099                request_id: prediction_id.to_string(),
1100                reason,
1101                was_shown,
1102            })
1103            .log_err();
1104    }
1105
1106    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1107        self.projects
1108            .get(&project.entity_id())
1109            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1110    }
1111
1112    pub fn refresh_prediction_from_buffer(
1113        &mut self,
1114        project: Entity<Project>,
1115        buffer: Entity<Buffer>,
1116        position: language::Anchor,
1117        cx: &mut Context<Self>,
1118    ) {
1119        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1120            let Some(request_task) = this
1121                .update(cx, |this, cx| {
1122                    this.request_prediction(
1123                        &project,
1124                        &buffer,
1125                        position,
1126                        PredictEditsRequestTrigger::Other,
1127                        cx,
1128                    )
1129                })
1130                .log_err()
1131            else {
1132                return Task::ready(anyhow::Ok(None));
1133            };
1134
1135            cx.spawn(async move |_cx| {
1136                request_task.await.map(|prediction_result| {
1137                    prediction_result.map(|prediction_result| {
1138                        (
1139                            prediction_result,
1140                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1141                        )
1142                    })
1143                })
1144            })
1145        })
1146    }
1147
1148    pub fn refresh_prediction_from_diagnostics(
1149        &mut self,
1150        project: Entity<Project>,
1151        cx: &mut Context<Self>,
1152    ) {
1153        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1154            return;
1155        };
1156
1157        // Prefer predictions from buffer
1158        if project_state.current_prediction.is_some() {
1159            return;
1160        };
1161
1162        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1163            let Some((active_buffer, snapshot, cursor_point)) = this
1164                .read_with(cx, |this, cx| {
1165                    let project_state = this.projects.get(&project.entity_id())?;
1166                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
1167                    let snapshot = buffer.read(cx).snapshot();
1168
1169                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
1170                        return None;
1171                    }
1172
1173                    let cursor_point = position
1174                        .map(|pos| pos.to_point(&snapshot))
1175                        .unwrap_or_default();
1176
1177                    Some((buffer, snapshot, cursor_point))
1178                })
1179                .log_err()
1180                .flatten()
1181            else {
1182                return Task::ready(anyhow::Ok(None));
1183            };
1184
1185            cx.spawn(async move |cx| {
1186                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1187                    active_buffer,
1188                    &snapshot,
1189                    Default::default(),
1190                    cursor_point,
1191                    &project,
1192                    cx,
1193                )
1194                .await?
1195                else {
1196                    return anyhow::Ok(None);
1197                };
1198
1199                let Some(prediction_result) = this
1200                    .update(cx, |this, cx| {
1201                        this.request_prediction(
1202                            &project,
1203                            &jump_buffer,
1204                            jump_position,
1205                            PredictEditsRequestTrigger::Diagnostics,
1206                            cx,
1207                        )
1208                    })?
1209                    .await?
1210                else {
1211                    return anyhow::Ok(None);
1212                };
1213
1214                this.update(cx, |this, cx| {
1215                    Some((
1216                        if this
1217                            .get_or_init_project(&project, cx)
1218                            .current_prediction
1219                            .is_none()
1220                        {
1221                            prediction_result
1222                        } else {
1223                            EditPredictionResult {
1224                                id: prediction_result.id,
1225                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1226                            }
1227                        },
1228                        PredictionRequestedBy::DiagnosticsUpdate,
1229                    ))
1230                })
1231            })
1232        });
1233    }
1234
1235    fn predictions_enabled_at(
1236        snapshot: &BufferSnapshot,
1237        position: Option<language::Anchor>,
1238        cx: &App,
1239    ) -> bool {
1240        let file = snapshot.file();
1241        let all_settings = all_language_settings(file, cx);
1242        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1243            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1244        {
1245            return false;
1246        }
1247
1248        if let Some(last_position) = position {
1249            let settings = snapshot.settings_at(last_position, cx);
1250
1251            if !settings.edit_predictions_disabled_in.is_empty()
1252                && let Some(scope) = snapshot.language_scope_at(last_position)
1253                && let Some(scope_name) = scope.override_name()
1254                && settings
1255                    .edit_predictions_disabled_in
1256                    .iter()
1257                    .any(|s| s == scope_name)
1258            {
1259                return false;
1260            }
1261        }
1262
1263        true
1264    }
1265
    /// Minimum delay between two consecutive prediction refreshes for the same throttle
    /// entity. It is measured from when the previous refresh started its request.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Zero under test, so refreshes effectively run without throttling.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1270
1271    fn queue_prediction_refresh(
1272        &mut self,
1273        project: Entity<Project>,
1274        throttle_entity: EntityId,
1275        cx: &mut Context<Self>,
1276        do_refresh: impl FnOnce(
1277            WeakEntity<Self>,
1278            &mut AsyncApp,
1279        )
1280            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1281        + 'static,
1282    ) {
1283        let project_state = self.get_or_init_project(&project, cx);
1284        let pending_prediction_id = project_state.next_pending_prediction_id;
1285        project_state.next_pending_prediction_id += 1;
1286        let last_request = project_state.last_prediction_refresh;
1287
1288        let task = cx.spawn(async move |this, cx| {
1289            if let Some((last_entity, last_timestamp)) = last_request
1290                && throttle_entity == last_entity
1291                && let Some(timeout) =
1292                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1293            {
1294                cx.background_executor().timer(timeout).await;
1295            }
1296
1297            // If this task was cancelled before the throttle timeout expired,
1298            // do not perform a request.
1299            let mut is_cancelled = true;
1300            this.update(cx, |this, cx| {
1301                let project_state = this.get_or_init_project(&project, cx);
1302                if !project_state
1303                    .cancelled_predictions
1304                    .remove(&pending_prediction_id)
1305                {
1306                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1307                    is_cancelled = false;
1308                }
1309            })
1310            .ok();
1311            if is_cancelled {
1312                return None;
1313            }
1314
1315            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1316            let new_prediction_id = new_prediction_result
1317                .as_ref()
1318                .map(|(prediction, _)| prediction.id.clone());
1319
1320            // When a prediction completes, remove it from the pending list, and cancel
1321            // any pending predictions that were enqueued before it.
1322            this.update(cx, |this, cx| {
1323                let project_state = this.get_or_init_project(&project, cx);
1324
1325                let is_cancelled = project_state
1326                    .cancelled_predictions
1327                    .remove(&pending_prediction_id);
1328
1329                let new_current_prediction = if !is_cancelled
1330                    && let Some((prediction_result, requested_by)) = new_prediction_result
1331                {
1332                    match prediction_result.prediction {
1333                        Ok(prediction) => {
1334                            let new_prediction = CurrentEditPrediction {
1335                                requested_by,
1336                                prediction,
1337                                was_shown: false,
1338                            };
1339
1340                            if let Some(current_prediction) =
1341                                project_state.current_prediction.as_ref()
1342                            {
1343                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1344                                {
1345                                    this.reject_current_prediction(
1346                                        EditPredictionRejectReason::Replaced,
1347                                        &project,
1348                                    );
1349
1350                                    Some(new_prediction)
1351                                } else {
1352                                    this.reject_prediction(
1353                                        new_prediction.prediction.id,
1354                                        EditPredictionRejectReason::CurrentPreferred,
1355                                        false,
1356                                    );
1357                                    None
1358                                }
1359                            } else {
1360                                Some(new_prediction)
1361                            }
1362                        }
1363                        Err(reject_reason) => {
1364                            this.reject_prediction(prediction_result.id, reject_reason, false);
1365                            None
1366                        }
1367                    }
1368                } else {
1369                    None
1370                };
1371
1372                let project_state = this.get_or_init_project(&project, cx);
1373
1374                if let Some(new_prediction) = new_current_prediction {
1375                    project_state.current_prediction = Some(new_prediction);
1376                }
1377
1378                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1379                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1380                    if pending_prediction.id == pending_prediction_id {
1381                        pending_predictions.remove(ix);
1382                        for pending_prediction in pending_predictions.drain(0..ix) {
1383                            project_state.cancel_pending_prediction(pending_prediction, cx)
1384                        }
1385                        break;
1386                    }
1387                }
1388                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1389                cx.notify();
1390            })
1391            .ok();
1392
1393            new_prediction_id
1394        });
1395
1396        if project_state.pending_predictions.len() <= 1 {
1397            project_state.pending_predictions.push(PendingPrediction {
1398                id: pending_prediction_id,
1399                task,
1400            });
1401        } else if project_state.pending_predictions.len() == 2 {
1402            let pending_prediction = project_state.pending_predictions.pop().unwrap();
1403            project_state.pending_predictions.push(PendingPrediction {
1404                id: pending_prediction_id,
1405                task,
1406            });
1407            project_state.cancel_pending_prediction(pending_prediction, cx);
1408        }
1409    }
1410
1411    pub fn request_prediction(
1412        &mut self,
1413        project: &Entity<Project>,
1414        active_buffer: &Entity<Buffer>,
1415        position: language::Anchor,
1416        trigger: PredictEditsRequestTrigger,
1417        cx: &mut Context<Self>,
1418    ) -> Task<Result<Option<EditPredictionResult>>> {
1419        self.request_prediction_internal(
1420            project.clone(),
1421            active_buffer.clone(),
1422            position,
1423            trigger,
1424            cx.has_flag::<Zeta2FeatureFlag>(),
1425            cx,
1426        )
1427    }
1428
1429    fn request_prediction_internal(
1430        &mut self,
1431        project: Entity<Project>,
1432        active_buffer: Entity<Buffer>,
1433        position: language::Anchor,
1434        trigger: PredictEditsRequestTrigger,
1435        allow_jump: bool,
1436        cx: &mut Context<Self>,
1437    ) -> Task<Result<Option<EditPredictionResult>>> {
1438        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1439
1440        self.get_or_init_project(&project, cx);
1441        let project_state = self.projects.get(&project.entity_id()).unwrap();
1442        let events = project_state.events(cx);
1443        let has_events = !events.is_empty();
1444        let debug_tx = project_state.debug_tx.clone();
1445
1446        let snapshot = active_buffer.read(cx).snapshot();
1447        let cursor_point = position.to_point(&snapshot);
1448        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1449        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1450        let diagnostic_search_range =
1451            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1452
1453        let related_files = if self.use_context {
1454            self.context_for_project(&project, cx)
1455        } else {
1456            Vec::new().into()
1457        };
1458
1459        let inputs = EditPredictionModelInput {
1460            project: project.clone(),
1461            buffer: active_buffer.clone(),
1462            snapshot: snapshot.clone(),
1463            position,
1464            events,
1465            related_files,
1466            recent_paths: project_state.recent_paths.clone(),
1467            trigger,
1468            diagnostic_search_range: diagnostic_search_range.clone(),
1469            debug_tx,
1470        };
1471
1472        let task = match self.edit_prediction_model {
1473            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
1474            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
1475            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1476            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1477        };
1478
1479        cx.spawn(async move |this, cx| {
1480            let prediction = task.await?;
1481
1482            if prediction.is_none() && allow_jump {
1483                let cursor_point = position.to_point(&snapshot);
1484                if has_events
1485                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1486                        active_buffer.clone(),
1487                        &snapshot,
1488                        diagnostic_search_range,
1489                        cursor_point,
1490                        &project,
1491                        cx,
1492                    )
1493                    .await?
1494                {
1495                    return this
1496                        .update(cx, |this, cx| {
1497                            this.request_prediction_internal(
1498                                project,
1499                                jump_buffer,
1500                                jump_position,
1501                                trigger,
1502                                false,
1503                                cx,
1504                            )
1505                        })?
1506                        .await;
1507                }
1508
1509                return anyhow::Ok(None);
1510            }
1511
1512            Ok(prediction)
1513        })
1514    }
1515
    /// Finds a location to jump to for a diagnostics-driven prediction.
    ///
    /// First it looks in `active_buffer` at the primary entry of each diagnostic group.
    /// Entries whose range overlaps `active_buffer_diagnostic_search_range` are skipped.
    /// Of the rest, it picks the start of the one whose row is closest to
    /// `active_buffer_cursor_point`.
    ///
    /// If there is none, it falls back to another project path that has diagnostics. It
    /// chooses the path sharing the longest run of leading path components with the active
    /// buffer's path, opens it, and picks the start of its diagnostic with the smallest
    /// `severity` value. Under LSP numbering that is the most severe one.
    ///
    /// Returns `Ok(None)` when no candidate exists. Errors from entity access or from
    /// opening the buffer are propagated.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Fallback: look in another file that has diagnostics.
        if jump_location.is_none() {
            // `None` when the active buffer is not backed by a file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            // NOTE(review): the `false` argument to `diagnostic_summaries` presumably excludes
            // ignored entries — confirm against `Project::diagnostic_summaries`.
            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the path with diagnostics that shares the most leading path components
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1594
1595    async fn send_raw_llm_request(
1596        request: open_ai::Request,
1597        client: Arc<Client>,
1598        llm_token: LlmApiToken,
1599        app_version: Version,
1600        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1601        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
1602    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1603        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1604            http_client::Url::parse(&predict_edits_url)?
1605        } else {
1606            client
1607                .http_client()
1608                .build_zed_llm_url("/predict_edits/raw", &[])?
1609        };
1610
1611        #[cfg(feature = "cli-support")]
1612        let cache_key = if let Some(cache) = eval_cache {
1613            use collections::FxHasher;
1614            use std::hash::{Hash, Hasher};
1615
1616            let mut hasher = FxHasher::default();
1617            url.hash(&mut hasher);
1618            let request_str = serde_json::to_string_pretty(&request)?;
1619            request_str.hash(&mut hasher);
1620            let hash = hasher.finish();
1621
1622            let key = (eval_cache_kind, hash);
1623            if let Some(response_str) = cache.read(key) {
1624                return Ok((serde_json::from_str(&response_str)?, None));
1625            }
1626
1627            Some((cache, request_str, key))
1628        } else {
1629            None
1630        };
1631
1632        let (response, usage) = Self::send_api_request(
1633            |builder| {
1634                let req = builder
1635                    .uri(url.as_ref())
1636                    .body(serde_json::to_string(&request)?.into());
1637                Ok(req?)
1638            },
1639            client,
1640            llm_token,
1641            app_version,
1642        )
1643        .await?;
1644
1645        #[cfg(feature = "cli-support")]
1646        if let Some((cache, request, key)) = cache_key {
1647            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1648        }
1649
1650        Ok((response, usage))
1651    }
1652
1653    fn handle_api_response<T>(
1654        this: &WeakEntity<Self>,
1655        response: Result<(T, Option<EditPredictionUsage>)>,
1656        cx: &mut gpui::AsyncApp,
1657    ) -> Result<T> {
1658        match response {
1659            Ok((data, usage)) => {
1660                if let Some(usage) = usage {
1661                    this.update(cx, |this, cx| {
1662                        this.user_store.update(cx, |user_store, cx| {
1663                            user_store.update_edit_prediction_usage(usage, cx);
1664                        });
1665                    })
1666                    .ok();
1667                }
1668                Ok(data)
1669            }
1670            Err(err) => {
1671                if err.is::<ZedUpdateRequiredError>() {
1672                    cx.update(|cx| {
1673                        this.update(cx, |this, _cx| {
1674                            this.update_required = true;
1675                        })
1676                        .ok();
1677
1678                        let error_message: SharedString = err.to_string().into();
1679                        show_app_notification(
1680                            NotificationId::unique::<ZedUpdateRequiredError>(),
1681                            cx,
1682                            move |cx| {
1683                                cx.new(|cx| {
1684                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1685                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1686                                })
1687                            },
1688                        );
1689                    })
1690                    .ok();
1691                }
1692                Err(err)
1693            }
1694        }
1695    }
1696
1697    async fn send_api_request<Res>(
1698        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1699        client: Arc<Client>,
1700        llm_token: LlmApiToken,
1701        app_version: Version,
1702    ) -> Result<(Res, Option<EditPredictionUsage>)>
1703    where
1704        Res: DeserializeOwned,
1705    {
1706        let http_client = client.http_client();
1707        let mut token = llm_token.acquire(&client).await?;
1708        let mut did_retry = false;
1709
1710        loop {
1711            let request_builder = http_client::Request::builder().method(Method::POST);
1712
1713            let request = build(
1714                request_builder
1715                    .header("Content-Type", "application/json")
1716                    .header("Authorization", format!("Bearer {}", token))
1717                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
1718            )?;
1719
1720            let mut response = http_client.send(request).await?;
1721
1722            if let Some(minimum_required_version) = response
1723                .headers()
1724                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1725                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1726            {
1727                anyhow::ensure!(
1728                    app_version >= minimum_required_version,
1729                    ZedUpdateRequiredError {
1730                        minimum_version: minimum_required_version
1731                    }
1732                );
1733            }
1734
1735            if response.status().is_success() {
1736                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1737
1738                let mut body = Vec::new();
1739                response.body_mut().read_to_end(&mut body).await?;
1740                return Ok((serde_json::from_slice(&body)?, usage));
1741            } else if !did_retry
1742                && response
1743                    .headers()
1744                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1745                    .is_some()
1746            {
1747                did_retry = true;
1748                token = llm_token.refresh(&client).await?;
1749            } else {
1750                let mut body = String::new();
1751                response.body_mut().read_to_string(&mut body).await?;
1752                anyhow::bail!(
1753                    "Request failed with status: {:?}\nBody: {}",
1754                    response.status(),
1755                    body
1756                );
1757            }
1758        }
1759    }
1760
1761    pub fn refresh_context(
1762        &mut self,
1763        project: &Entity<Project>,
1764        buffer: &Entity<language::Buffer>,
1765        cursor_position: language::Anchor,
1766        cx: &mut Context<Self>,
1767    ) {
1768        if self.use_context {
1769            self.get_or_init_project(project, cx)
1770                .context
1771                .update(cx, |store, cx| {
1772                    store.refresh(buffer.clone(), cursor_position, cx);
1773                });
1774        }
1775    }
1776
1777    #[cfg(feature = "cli-support")]
1778    pub fn set_context_for_buffer(
1779        &mut self,
1780        project: &Entity<Project>,
1781        related_files: Vec<RelatedFile>,
1782        cx: &mut Context<Self>,
1783    ) {
1784        self.get_or_init_project(project, cx)
1785            .context
1786            .update(cx, |store, _| {
1787                store.set_related_files(related_files);
1788            });
1789    }
1790
1791    fn is_file_open_source(
1792        &self,
1793        project: &Entity<Project>,
1794        file: &Arc<dyn File>,
1795        cx: &App,
1796    ) -> bool {
1797        if !file.is_local() || file.is_private() {
1798            return false;
1799        }
1800        let Some(project_state) = self.projects.get(&project.entity_id()) else {
1801            return false;
1802        };
1803        project_state
1804            .license_detection_watchers
1805            .get(&file.worktree_id(cx))
1806            .as_ref()
1807            .is_some_and(|watcher| watcher.is_project_open_source())
1808    }
1809
1810    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1811        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1812    }
1813
1814    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
1815        if !self.data_collection_choice.is_enabled() {
1816            return false;
1817        }
1818        events.iter().all(|event| {
1819            matches!(
1820                event.as_ref(),
1821                zeta_prompt::Event::BufferChange {
1822                    in_open_source_repo: true,
1823                    ..
1824                }
1825            )
1826        })
1827    }
1828
1829    fn load_data_collection_choice() -> DataCollectionChoice {
1830        let choice = KEY_VALUE_STORE
1831            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1832            .log_err()
1833            .flatten();
1834
1835        match choice.as_deref() {
1836            Some("true") => DataCollectionChoice::Enabled,
1837            Some("false") => DataCollectionChoice::Disabled,
1838            Some(_) => {
1839                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1840                DataCollectionChoice::NotAnswered
1841            }
1842            None => DataCollectionChoice::NotAnswered,
1843        }
1844    }
1845
1846    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1847        self.data_collection_choice = self.data_collection_choice.toggle();
1848        let new_choice = self.data_collection_choice;
1849        db::write_and_log(cx, move || {
1850            KEY_VALUE_STORE.write_kvp(
1851                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1852                new_choice.is_enabled().to_string(),
1853            )
1854        });
1855    }
1856
1857    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1858        self.shown_predictions.iter()
1859    }
1860
1861    pub fn shown_completions_len(&self) -> usize {
1862        self.shown_predictions.len()
1863    }
1864
1865    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
1866        self.rated_predictions.contains(id)
1867    }
1868
1869    pub fn rate_prediction(
1870        &mut self,
1871        prediction: &EditPrediction,
1872        rating: EditPredictionRating,
1873        feedback: String,
1874        cx: &mut Context<Self>,
1875    ) {
1876        self.rated_predictions.insert(prediction.id.clone());
1877        telemetry::event!(
1878            "Edit Prediction Rated",
1879            rating,
1880            inputs = prediction.inputs,
1881            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
1882            feedback
1883        );
1884        self.client.telemetry().flush_events().detach();
1885        cx.notify();
1886    }
1887
1888    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
1889        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
1890            && all_language_settings(None, cx).edit_predictions.use_context;
1891    }
1892}
1893
1894#[derive(Error, Debug)]
1895#[error(
1896    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1897)]
1898pub struct ZedUpdateRequiredError {
1899    minimum_version: Version,
1900}
1901
/// Key of an eval-cache entry: the entry kind plus a 64-bit hash.
///
/// In `send_raw_llm_request` the hash covers the request URL and the
/// pretty-printed request.
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

/// Category of a cached eval entry.
///
/// Derives `Eq` and `Hash` (not just `PartialEq`) so that [`EvalCacheKey`]
/// can be used directly as a `HashMap`/`HashSet` key by [`EvalCache`]
/// implementations.
#[cfg(feature = "cli-support")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}

#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the lowercase name of the kind (`context`, `search`,
    /// `prediction`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(label)
    }
}

/// Cache of serialized LLM responses keyed by [`EvalCacheKey`], available with
/// the `cli-support` feature.
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached value stored under `key`, if any.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` under `key`. `input` is the serialized request that
    /// produced it.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
1929
/// The user's answer to the edit-prediction data-collection prompt.
///
/// Derives `PartialEq`/`Eq` so choices can be compared directly instead of
/// going through `matches!` or the boolean helpers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        match self {
            Self::Enabled => true,
            Self::NotAnswered | Self::Disabled => false,
        }
    }

    /// Returns `true` once the user has made either choice.
    pub fn is_answered(self) -> bool {
        match self {
            Self::Enabled | Self::Disabled => true,
            Self::NotAnswered => false,
        }
    }

    /// Returns the opposite choice. An unanswered choice toggles to `Enabled`.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled => Self::Enabled,
            Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        match value {
            true => DataCollectionChoice::Enabled,
            false => DataCollectionChoice::Disabled,
        }
    }
}
1970
1971struct ZedPredictUpsell;
1972
1973impl Dismissable for ZedPredictUpsell {
1974    const KEY: &'static str = "dismissed-edit-predict-upsell";
1975
1976    fn dismissed() -> bool {
1977        // To make this backwards compatible with older versions of Zed, we
1978        // check if the user has seen the previous Edit Prediction Onboarding
1979        // before, by checking the data collection choice which was written to
1980        // the database once the user clicked on "Accept and Enable"
1981        if KEY_VALUE_STORE
1982            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1983            .log_err()
1984            .is_some_and(|s| s.is_some())
1985        {
1986            return true;
1987        }
1988
1989        KEY_VALUE_STORE
1990            .read_kvp(Self::KEY)
1991            .log_err()
1992            .is_some_and(|s| s.is_some())
1993    }
1994}
1995
1996pub fn should_show_upsell_modal() -> bool {
1997    !ZedPredictUpsell::dismissed()
1998}
1999
2000pub fn init(cx: &mut App) {
2001    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2002        workspace.register_action(
2003            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2004                ZedPredictModal::toggle(
2005                    workspace,
2006                    workspace.user_store().clone(),
2007                    workspace.client().clone(),
2008                    window,
2009                    cx,
2010                )
2011            },
2012        );
2013
2014        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2015            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2016                settings
2017                    .project
2018                    .all_languages
2019                    .features
2020                    .get_or_insert_default()
2021                    .edit_prediction_provider = Some(EditPredictionProvider::None)
2022            });
2023        });
2024    })
2025    .detach();
2026}