edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
   7    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
   8    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
   9    ZED_VERSION_HEADER_NAME,
  10};
  11use collections::{HashMap, HashSet};
  12use db::kvp::{Dismissable, KEY_VALUE_STORE};
  13use edit_prediction_context::EditPredictionExcerptOptions;
  14use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  16use futures::{
  17    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  18    channel::mpsc::{self, UnboundedReceiver},
  19    select_biased,
  20};
  21use gpui::BackgroundExecutor;
  22use gpui::{
  23    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  24    http_client::{self, AsyncBody, Method},
  25    prelude::*,
  26};
  27use language::language_settings::all_language_settings;
  28use language::{Anchor, Buffer, File, Point, ToPoint};
  29use language::{BufferSnapshot, OffsetRangeExt};
  30use language_model::{LlmApiToken, RefreshLlmTokenListener};
  31use project::{Project, ProjectPath, WorktreeId};
  32use release_channel::AppVersion;
  33use semver::Version;
  34use serde::de::DeserializeOwned;
  35use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
  36use std::collections::{VecDeque, hash_map};
  37use workspace::Workspace;
  38
  39use std::ops::Range;
  40use std::path::Path;
  41use std::rc::Rc;
  42use std::str::FromStr as _;
  43use std::sync::{Arc, LazyLock};
  44use std::time::{Duration, Instant};
  45use std::{env, mem};
  46use thiserror::Error;
  47use util::{RangeExt as _, ResultExt as _};
  48use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  49
  50mod cursor_excerpt;
  51mod license_detection;
  52pub mod mercury;
  53mod onboarding_modal;
  54pub mod open_ai_response;
  55mod prediction;
  56pub mod sweep_ai;
  57
  58#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
  59pub mod udiff;
  60
  61mod zed_edit_prediction_delegate;
  62pub mod zeta1;
  63pub mod zeta2;
  64
  65#[cfg(test)]
  66mod edit_prediction_tests;
  67
  68use crate::license_detection::LicenseDetectionWatcher;
  69use crate::mercury::Mercury;
  70use crate::onboarding_modal::ZedPredictModal;
  71pub use crate::prediction::EditPrediction;
  72pub use crate::prediction::EditPredictionId;
  73use crate::prediction::EditPredictionResult;
  74pub use crate::sweep_ai::SweepAi;
  75pub use telemetry_events::EditPredictionRating;
  76pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  77
  78actions!(
  79    edit_prediction,
  80    [
  81        /// Resets the edit prediction onboarding state.
  82        ResetOnboarding,
  83        /// Clears the edit prediction history.
  84        ClearHistory,
  85    ]
  86);
  87
  88/// Maximum number of events to track.
  89const EVENT_COUNT_MAX: usize = 6;
  90const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
  91const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
  92const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
  93
  94pub struct SweepFeatureFlag;
  95
  96impl FeatureFlag for SweepFeatureFlag {
  97    const NAME: &str = "sweep-ai";
  98}
  99
 100pub struct MercuryFeatureFlag;
 101
 102impl FeatureFlag for MercuryFeatureFlag {
 103    const NAME: &str = "mercury";
 104}
 105
 106pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
 107    context: EditPredictionExcerptOptions {
 108        max_bytes: 512,
 109        min_bytes: 128,
 110        target_before_cursor_over_total_bytes: 0.5,
 111    },
 112    prompt_format: PromptFormat::DEFAULT,
 113};
 114
 115static USE_OLLAMA: LazyLock<bool> =
 116    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 117
 118static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 119    match env::var("ZED_ZETA2_MODEL").as_deref() {
 120        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 121        Ok(model) => model,
 122        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 123        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 124    }
 125    .to_string()
 126});
 127static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 128    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 129        if *USE_OLLAMA {
 130            Some("http://localhost:11434/v1/chat/completions".into())
 131        } else {
 132            None
 133        }
 134    })
 135});
 136
 137pub struct Zeta2FeatureFlag;
 138
 139impl FeatureFlag for Zeta2FeatureFlag {
 140    const NAME: &'static str = "zeta2";
 141
 142    fn enabled_for_staff() -> bool {
 143        true
 144    }
 145}
 146
 147#[derive(Clone)]
 148struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
 149
 150impl Global for EditPredictionStoreGlobal {}
 151
 152pub struct EditPredictionStore {
 153    client: Arc<Client>,
 154    user_store: Entity<UserStore>,
 155    llm_token: LlmApiToken,
 156    _llm_token_subscription: Subscription,
 157    projects: HashMap<EntityId, ProjectState>,
 158    use_context: bool,
 159    options: ZetaOptions,
 160    update_required: bool,
 161    #[cfg(feature = "cli-support")]
 162    eval_cache: Option<Arc<dyn EvalCache>>,
 163    edit_prediction_model: EditPredictionModel,
 164    pub sweep_ai: SweepAi,
 165    pub mercury: Mercury,
 166    data_collection_choice: DataCollectionChoice,
 167    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
 168    shown_predictions: VecDeque<EditPrediction>,
 169    rated_predictions: HashSet<EditPredictionId>,
 170}
 171
 172#[derive(Copy, Clone, Default, PartialEq, Eq)]
 173pub enum EditPredictionModel {
 174    #[default]
 175    Zeta1,
 176    Zeta2,
 177    Sweep,
 178    Mercury,
 179}
 180
 181pub struct EditPredictionModelInput {
 182    project: Entity<Project>,
 183    buffer: Entity<Buffer>,
 184    snapshot: BufferSnapshot,
 185    position: Anchor,
 186    events: Vec<Arc<zeta_prompt::Event>>,
 187    related_files: Arc<[RelatedFile]>,
 188    recent_paths: VecDeque<ProjectPath>,
 189    trigger: PredictEditsRequestTrigger,
 190    diagnostic_search_range: Range<Point>,
 191    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 192}
 193
 194#[derive(Debug, Clone, PartialEq)]
 195pub struct ZetaOptions {
 196    pub context: EditPredictionExcerptOptions,
 197    pub prompt_format: predict_edits_v3::PromptFormat,
 198}
 199
 200#[derive(Debug)]
 201pub enum DebugEvent {
 202    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
 203    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
 204    EditPredictionStarted(EditPredictionStartedDebugEvent),
 205    EditPredictionFinished(EditPredictionFinishedDebugEvent),
 206}
 207
 208#[derive(Debug)]
 209pub struct ContextRetrievalStartedDebugEvent {
 210    pub project_entity_id: EntityId,
 211    pub timestamp: Instant,
 212    pub search_prompt: String,
 213}
 214
 215#[derive(Debug)]
 216pub struct ContextRetrievalFinishedDebugEvent {
 217    pub project_entity_id: EntityId,
 218    pub timestamp: Instant,
 219    pub metadata: Vec<(&'static str, SharedString)>,
 220}
 221
 222#[derive(Debug)]
 223pub struct EditPredictionStartedDebugEvent {
 224    pub buffer: WeakEntity<Buffer>,
 225    pub position: Anchor,
 226    pub prompt: Option<String>,
 227}
 228
 229#[derive(Debug)]
 230pub struct EditPredictionFinishedDebugEvent {
 231    pub buffer: WeakEntity<Buffer>,
 232    pub position: Anchor,
 233    pub model_output: Option<String>,
 234}
 235
 236pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 237
 238struct ProjectState {
 239    events: VecDeque<Arc<zeta_prompt::Event>>,
 240    last_event: Option<LastEvent>,
 241    recent_paths: VecDeque<ProjectPath>,
 242    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 243    current_prediction: Option<CurrentEditPrediction>,
 244    next_pending_prediction_id: usize,
 245    pending_predictions: ArrayVec<PendingPrediction, 2>,
 246    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 247    last_prediction_refresh: Option<(EntityId, Instant)>,
 248    cancelled_predictions: HashSet<usize>,
 249    context: Entity<RelatedExcerptStore>,
 250    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 251    _subscription: gpui::Subscription,
 252}
 253
 254impl ProjectState {
 255    pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
 256        self.events
 257            .iter()
 258            .cloned()
 259            .chain(
 260                self.last_event
 261                    .as_ref()
 262                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 263            )
 264            .collect()
 265    }
 266
 267    fn cancel_pending_prediction(
 268        &mut self,
 269        pending_prediction: PendingPrediction,
 270        cx: &mut Context<EditPredictionStore>,
 271    ) {
 272        self.cancelled_predictions.insert(pending_prediction.id);
 273
 274        cx.spawn(async move |this, cx| {
 275            let Some(prediction_id) = pending_prediction.task.await else {
 276                return;
 277            };
 278
 279            this.update(cx, |this, _cx| {
 280                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 281            })
 282            .ok();
 283        })
 284        .detach()
 285    }
 286
 287    fn active_buffer(
 288        &self,
 289        project: &Entity<Project>,
 290        cx: &App,
 291    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 292        let project = project.read(cx);
 293        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 294        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 295        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 296        Some((active_buffer, registered_buffer.last_position))
 297    }
 298}
 299
 300#[derive(Debug, Clone)]
 301struct CurrentEditPrediction {
 302    pub requested_by: PredictionRequestedBy,
 303    pub prediction: EditPrediction,
 304    pub was_shown: bool,
 305}
 306
 307impl CurrentEditPrediction {
 308    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 309        let Some(new_edits) = self
 310            .prediction
 311            .interpolate(&self.prediction.buffer.read(cx))
 312        else {
 313            return false;
 314        };
 315
 316        if self.prediction.buffer != old_prediction.prediction.buffer {
 317            return true;
 318        }
 319
 320        let Some(old_edits) = old_prediction
 321            .prediction
 322            .interpolate(&old_prediction.prediction.buffer.read(cx))
 323        else {
 324            return true;
 325        };
 326
 327        let requested_by_buffer_id = self.requested_by.buffer_id();
 328
 329        // This reduces the occurrence of UI thrash from replacing edits
 330        //
 331        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 332        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 333            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 334            && old_edits.len() == 1
 335            && new_edits.len() == 1
 336        {
 337            let (old_range, old_text) = &old_edits[0];
 338            let (new_range, new_text) = &new_edits[0];
 339            new_range == old_range && new_text.starts_with(old_text.as_ref())
 340        } else {
 341            true
 342        }
 343    }
 344}
 345
 346#[derive(Debug, Clone)]
 347enum PredictionRequestedBy {
 348    DiagnosticsUpdate,
 349    Buffer(EntityId),
 350}
 351
 352impl PredictionRequestedBy {
 353    pub fn buffer_id(&self) -> Option<EntityId> {
 354        match self {
 355            PredictionRequestedBy::DiagnosticsUpdate => None,
 356            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 357        }
 358    }
 359}
 360
 361#[derive(Debug)]
 362struct PendingPrediction {
 363    id: usize,
 364    task: Task<Option<EditPredictionId>>,
 365}
 366
 367/// A prediction from the perspective of a buffer.
 368#[derive(Debug)]
 369enum BufferEditPrediction<'a> {
 370    Local { prediction: &'a EditPrediction },
 371    Jump { prediction: &'a EditPrediction },
 372}
 373
 374#[cfg(test)]
 375impl std::ops::Deref for BufferEditPrediction<'_> {
 376    type Target = EditPrediction;
 377
 378    fn deref(&self) -> &Self::Target {
 379        match self {
 380            BufferEditPrediction::Local { prediction } => prediction,
 381            BufferEditPrediction::Jump { prediction } => prediction,
 382        }
 383    }
 384}
 385
 386struct RegisteredBuffer {
 387    snapshot: BufferSnapshot,
 388    last_position: Option<Anchor>,
 389    _subscriptions: [gpui::Subscription; 2],
 390}
 391
 392struct LastEvent {
 393    old_snapshot: BufferSnapshot,
 394    new_snapshot: BufferSnapshot,
 395    end_edit_anchor: Option<Anchor>,
 396}
 397
 398impl LastEvent {
 399    pub fn finalize(
 400        &self,
 401        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 402        cx: &App,
 403    ) -> Option<Arc<zeta_prompt::Event>> {
 404        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
 405        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
 406
 407        let file = self.new_snapshot.file();
 408        let old_file = self.old_snapshot.file();
 409
 410        let in_open_source_repo = [file, old_file].iter().all(|file| {
 411            file.is_some_and(|file| {
 412                license_detection_watchers
 413                    .get(&file.worktree_id(cx))
 414                    .is_some_and(|watcher| watcher.is_project_open_source())
 415            })
 416        });
 417
 418        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
 419
 420        if path == old_path && diff.is_empty() {
 421            None
 422        } else {
 423            Some(Arc::new(zeta_prompt::Event::BufferChange {
 424                old_path,
 425                path,
 426                diff,
 427                in_open_source_repo,
 428                // TODO: Actually detect if this edit was predicted or not
 429                predicted: false,
 430            }))
 431        }
 432    }
 433}
 434
 435fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
 436    if let Some(file) = snapshot.file() {
 437        file.full_path(cx).into()
 438    } else {
 439        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 440    }
 441}
 442
 443impl EditPredictionStore {
 444    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 445        cx.try_global::<EditPredictionStoreGlobal>()
 446            .map(|global| global.0.clone())
 447    }
 448
 449    pub fn global(
 450        client: &Arc<Client>,
 451        user_store: &Entity<UserStore>,
 452        cx: &mut App,
 453    ) -> Entity<Self> {
 454        cx.try_global::<EditPredictionStoreGlobal>()
 455            .map(|global| global.0.clone())
 456            .unwrap_or_else(|| {
 457                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 458                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 459                ep_store
 460            })
 461    }
 462
 463    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 464        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 465        let data_collection_choice = Self::load_data_collection_choice();
 466
 467        let llm_token = LlmApiToken::default();
 468
 469        let (reject_tx, reject_rx) = mpsc::unbounded();
 470        cx.background_spawn({
 471            let client = client.clone();
 472            let llm_token = llm_token.clone();
 473            let app_version = AppVersion::global(cx);
 474            let background_executor = cx.background_executor().clone();
 475            async move {
 476                Self::handle_rejected_predictions(
 477                    reject_rx,
 478                    client,
 479                    llm_token,
 480                    app_version,
 481                    background_executor,
 482                )
 483                .await
 484            }
 485        })
 486        .detach();
 487
 488        let mut this = Self {
 489            projects: HashMap::default(),
 490            client,
 491            user_store,
 492            options: DEFAULT_OPTIONS,
 493            use_context: false,
 494            llm_token,
 495            _llm_token_subscription: cx.subscribe(
 496                &refresh_llm_token_listener,
 497                |this, _listener, _event, cx| {
 498                    let client = this.client.clone();
 499                    let llm_token = this.llm_token.clone();
 500                    cx.spawn(async move |_this, _cx| {
 501                        llm_token.refresh(&client).await?;
 502                        anyhow::Ok(())
 503                    })
 504                    .detach_and_log_err(cx);
 505                },
 506            ),
 507            update_required: false,
 508            #[cfg(feature = "cli-support")]
 509            eval_cache: None,
 510            edit_prediction_model: EditPredictionModel::Zeta2,
 511            sweep_ai: SweepAi::new(cx),
 512            mercury: Mercury::new(cx),
 513            data_collection_choice,
 514            reject_predictions_tx: reject_tx,
 515            rated_predictions: Default::default(),
 516            shown_predictions: Default::default(),
 517        };
 518
 519        this.configure_context_retrieval(cx);
 520        let weak_this = cx.weak_entity();
 521        cx.on_flags_ready(move |_, cx| {
 522            weak_this
 523                .update(cx, |this, cx| this.configure_context_retrieval(cx))
 524                .ok();
 525        })
 526        .detach();
 527        cx.observe_global::<SettingsStore>(|this, cx| {
 528            this.configure_context_retrieval(cx);
 529        })
 530        .detach();
 531
 532        this
 533    }
 534
 535    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 536        self.edit_prediction_model = model;
 537    }
 538
 539    pub fn has_sweep_api_token(&self) -> bool {
 540        self.sweep_ai
 541            .api_token
 542            .clone()
 543            .now_or_never()
 544            .flatten()
 545            .is_some()
 546    }
 547
 548    pub fn has_mercury_api_token(&self) -> bool {
 549        self.mercury
 550            .api_token
 551            .clone()
 552            .now_or_never()
 553            .flatten()
 554            .is_some()
 555    }
 556
 557    #[cfg(feature = "cli-support")]
 558    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
 559        self.eval_cache = Some(cache);
 560    }
 561
 562    pub fn options(&self) -> &ZetaOptions {
 563        &self.options
 564    }
 565
 566    pub fn set_options(&mut self, options: ZetaOptions) {
 567        self.options = options;
 568    }
 569
 570    pub fn set_use_context(&mut self, use_context: bool) {
 571        self.use_context = use_context;
 572    }
 573
 574    pub fn clear_history(&mut self) {
 575        for project_state in self.projects.values_mut() {
 576            project_state.events.clear();
 577        }
 578    }
 579
 580    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 581        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 582            project_state.events.clear();
 583        }
 584    }
 585
 586    pub fn edit_history_for_project(
 587        &self,
 588        project: &Entity<Project>,
 589    ) -> Vec<Arc<zeta_prompt::Event>> {
 590        self.projects
 591            .get(&project.entity_id())
 592            .map(|project_state| project_state.events.iter().cloned().collect())
 593            .unwrap_or_default()
 594    }
 595
 596    pub fn context_for_project<'a>(
 597        &'a self,
 598        project: &Entity<Project>,
 599        cx: &'a App,
 600    ) -> Arc<[RelatedFile]> {
 601        self.projects
 602            .get(&project.entity_id())
 603            .map(|project| project.context.read(cx).related_files())
 604            .unwrap_or_else(|| vec![].into())
 605    }
 606
 607    pub fn context_for_project_with_buffers<'a>(
 608        &'a self,
 609        project: &Entity<Project>,
 610        cx: &'a App,
 611    ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
 612        self.projects
 613            .get(&project.entity_id())
 614            .map(|project| project.context.read(cx).related_files_with_buffers())
 615    }
 616
 617    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 618        if self.edit_prediction_model == EditPredictionModel::Zeta2 {
 619            self.user_store.read(cx).edit_prediction_usage()
 620        } else {
 621            None
 622        }
 623    }
 624
 625    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 626        self.get_or_init_project(project, cx);
 627    }
 628
 629    pub fn register_buffer(
 630        &mut self,
 631        buffer: &Entity<Buffer>,
 632        project: &Entity<Project>,
 633        cx: &mut Context<Self>,
 634    ) {
 635        let project_state = self.get_or_init_project(project, cx);
 636        Self::register_buffer_impl(project_state, buffer, project, cx);
 637    }
 638
 639    fn get_or_init_project(
 640        &mut self,
 641        project: &Entity<Project>,
 642        cx: &mut Context<Self>,
 643    ) -> &mut ProjectState {
 644        let entity_id = project.entity_id();
 645        self.projects
 646            .entry(entity_id)
 647            .or_insert_with(|| ProjectState {
 648                context: {
 649                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 650                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 651                        this.handle_excerpt_store_event(entity_id, event);
 652                    })
 653                    .detach();
 654                    related_excerpt_store
 655                },
 656                events: VecDeque::new(),
 657                last_event: None,
 658                recent_paths: VecDeque::new(),
 659                debug_tx: None,
 660                registered_buffers: HashMap::default(),
 661                current_prediction: None,
 662                cancelled_predictions: HashSet::default(),
 663                pending_predictions: ArrayVec::new(),
 664                next_pending_prediction_id: 0,
 665                last_prediction_refresh: None,
 666                license_detection_watchers: HashMap::default(),
 667                _subscription: cx.subscribe(&project, Self::handle_project_event),
 668            })
 669    }
 670
 671    pub fn remove_project(&mut self, project: &Entity<Project>) {
 672        self.projects.remove(&project.entity_id());
 673    }
 674
 675    fn handle_excerpt_store_event(
 676        &mut self,
 677        project_entity_id: EntityId,
 678        event: &RelatedExcerptStoreEvent,
 679    ) {
 680        if let Some(project_state) = self.projects.get(&project_entity_id) {
 681            if let Some(debug_tx) = project_state.debug_tx.clone() {
 682                match event {
 683                    RelatedExcerptStoreEvent::StartedRefresh => {
 684                        debug_tx
 685                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 686                                ContextRetrievalStartedDebugEvent {
 687                                    project_entity_id: project_entity_id,
 688                                    timestamp: Instant::now(),
 689                                    search_prompt: String::new(),
 690                                },
 691                            ))
 692                            .ok();
 693                    }
 694                    RelatedExcerptStoreEvent::FinishedRefresh {
 695                        cache_hit_count,
 696                        cache_miss_count,
 697                        mean_definition_latency,
 698                        max_definition_latency,
 699                    } => {
 700                        debug_tx
 701                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 702                                ContextRetrievalFinishedDebugEvent {
 703                                    project_entity_id: project_entity_id,
 704                                    timestamp: Instant::now(),
 705                                    metadata: vec![
 706                                        (
 707                                            "Cache Hits",
 708                                            format!(
 709                                                "{}/{}",
 710                                                cache_hit_count,
 711                                                cache_hit_count + cache_miss_count
 712                                            )
 713                                            .into(),
 714                                        ),
 715                                        (
 716                                            "Max LSP Time",
 717                                            format!("{} ms", max_definition_latency.as_millis())
 718                                                .into(),
 719                                        ),
 720                                        (
 721                                            "Mean LSP Time",
 722                                            format!("{} ms", mean_definition_latency.as_millis())
 723                                                .into(),
 724                                        ),
 725                                    ],
 726                                },
 727                            ))
 728                            .ok();
 729                    }
 730                }
 731            }
 732        }
 733    }
 734
 735    pub fn debug_info(
 736        &mut self,
 737        project: &Entity<Project>,
 738        cx: &mut Context<Self>,
 739    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 740        let project_state = self.get_or_init_project(project, cx);
 741        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 742        project_state.debug_tx = Some(debug_watch_tx);
 743        debug_watch_rx
 744    }
 745
 746    fn handle_project_event(
 747        &mut self,
 748        project: Entity<Project>,
 749        event: &project::Event,
 750        cx: &mut Context<Self>,
 751    ) {
 752        // TODO [zeta2] init with recent paths
 753        match event {
 754            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 755                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 756                    return;
 757                };
 758                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 759                if let Some(path) = path {
 760                    if let Some(ix) = project_state
 761                        .recent_paths
 762                        .iter()
 763                        .position(|probe| probe == &path)
 764                    {
 765                        project_state.recent_paths.remove(ix);
 766                    }
 767                    project_state.recent_paths.push_front(path);
 768                }
 769            }
 770            project::Event::DiagnosticsUpdated { .. } => {
 771                if cx.has_flag::<Zeta2FeatureFlag>() {
 772                    self.refresh_prediction_from_diagnostics(project, cx);
 773                }
 774            }
 775            _ => (),
 776        }
 777    }
 778
 779    fn register_buffer_impl<'a>(
 780        project_state: &'a mut ProjectState,
 781        buffer: &Entity<Buffer>,
 782        project: &Entity<Project>,
 783        cx: &mut Context<Self>,
 784    ) -> &'a mut RegisteredBuffer {
 785        let buffer_id = buffer.entity_id();
 786
 787        if let Some(file) = buffer.read(cx).file() {
 788            let worktree_id = file.worktree_id(cx);
 789            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
 790                project_state
 791                    .license_detection_watchers
 792                    .entry(worktree_id)
 793                    .or_insert_with(|| {
 794                        let project_entity_id = project.entity_id();
 795                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
 796                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
 797                            else {
 798                                return;
 799                            };
 800                            project_state
 801                                .license_detection_watchers
 802                                .remove(&worktree_id);
 803                        })
 804                        .detach();
 805                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
 806                    });
 807            }
 808        }
 809
 810        match project_state.registered_buffers.entry(buffer_id) {
 811            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 812            hash_map::Entry::Vacant(entry) => {
 813                let snapshot = buffer.read(cx).snapshot();
 814                let project_entity_id = project.entity_id();
 815                entry.insert(RegisteredBuffer {
 816                    snapshot,
 817                    last_position: None,
 818                    _subscriptions: [
 819                        cx.subscribe(buffer, {
 820                            let project = project.downgrade();
 821                            move |this, buffer, event, cx| {
 822                                if let language::BufferEvent::Edited = event
 823                                    && let Some(project) = project.upgrade()
 824                                {
 825                                    this.report_changes_for_buffer(&buffer, &project, cx);
 826                                }
 827                            }
 828                        }),
 829                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 830                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
 831                            else {
 832                                return;
 833                            };
 834                            project_state.registered_buffers.remove(&buffer_id);
 835                        }),
 836                    ],
 837                })
 838            }
 839        }
 840    }
 841
    /// Records an edit to `buffer` in `project`'s edit history.
    ///
    /// The buffer is registered first if needed. Then its current snapshot is compared
    /// with the snapshot stored for it; if the version is unchanged, nothing happens.
    /// Otherwise the stored snapshot is replaced, and one of two things happens:
    /// - The change is coalesced into the pending `last_event`. This requires that the
    ///   event is for the same buffer, continues from its new snapshot, and that the last
    ///   edits end within `CHANGE_GROUPING_LINE_SPAN` rows of each other.
    /// - Otherwise the pending event is finalized into `events` and a new `last_event`
    ///   is started.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Nothing changed since the last time this buffer was reported.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // End anchor of the last edit made between the two snapshots, if any.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // This edit starts exactly where the pending event's snapshot left off.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // Merge edits that end close to each other (row distance measured in the
            // new snapshot) into a single event.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // Extend the pending event instead of starting a new one.
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Bound the history length.
        // NOTE(review): the oldest event is dropped even when there is no pending
        // `last_event` to finalize below.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
 904
 905    fn prediction_at(
 906        &mut self,
 907        buffer: &Entity<Buffer>,
 908        position: Option<language::Anchor>,
 909        project: &Entity<Project>,
 910        cx: &App,
 911    ) -> Option<BufferEditPrediction<'_>> {
 912        let project_state = self.projects.get_mut(&project.entity_id())?;
 913        if let Some(position) = position
 914            && let Some(buffer) = project_state
 915                .registered_buffers
 916                .get_mut(&buffer.entity_id())
 917        {
 918            buffer.last_position = Some(position);
 919        }
 920
 921        let CurrentEditPrediction {
 922            requested_by,
 923            prediction,
 924            ..
 925        } = project_state.current_prediction.as_ref()?;
 926
 927        if prediction.targets_buffer(buffer.read(cx)) {
 928            Some(BufferEditPrediction::Local { prediction })
 929        } else {
 930            let show_jump = match requested_by {
 931                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
 932                    requested_by_buffer_id == &buffer.entity_id()
 933                }
 934                PredictionRequestedBy::DiagnosticsUpdate => true,
 935            };
 936
 937            if show_jump {
 938                Some(BufferEditPrediction::Jump { prediction })
 939            } else {
 940                None
 941            }
 942        }
 943    }
 944
 945    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 946        match self.edit_prediction_model {
 947            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
 948            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
 949        }
 950
 951        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 952            return;
 953        };
 954
 955        let Some(prediction) = project_state.current_prediction.take() else {
 956            return;
 957        };
 958        let request_id = prediction.prediction.id.to_string();
 959        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
 960            project_state.cancel_pending_prediction(pending_prediction, cx);
 961        }
 962
 963        let client = self.client.clone();
 964        let llm_token = self.llm_token.clone();
 965        let app_version = AppVersion::global(cx);
 966        cx.spawn(async move |this, cx| {
 967            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
 968                http_client::Url::parse(&predict_edits_url)?
 969            } else {
 970                client
 971                    .http_client()
 972                    .build_zed_llm_url("/predict_edits/accept", &[])?
 973            };
 974
 975            let response = cx
 976                .background_spawn(Self::send_api_request::<()>(
 977                    move |builder| {
 978                        let req = builder.uri(url.as_ref()).body(
 979                            serde_json::to_string(&AcceptEditPredictionBody {
 980                                request_id: request_id.clone(),
 981                            })?
 982                            .into(),
 983                        );
 984                        Ok(req?)
 985                    },
 986                    client,
 987                    llm_token,
 988                    app_version,
 989                ))
 990                .await;
 991
 992            Self::handle_api_response(&this, response, cx)?;
 993            anyhow::Ok(())
 994        })
 995        .detach_and_log_err(cx);
 996    }
 997
 998    async fn handle_rejected_predictions(
 999        rx: UnboundedReceiver<EditPredictionRejection>,
1000        client: Arc<Client>,
1001        llm_token: LlmApiToken,
1002        app_version: Version,
1003        background_executor: BackgroundExecutor,
1004    ) {
1005        let mut rx = std::pin::pin!(rx.peekable());
1006        let mut batched = Vec::new();
1007
1008        while let Some(rejection) = rx.next().await {
1009            batched.push(rejection);
1010
1011            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1012                select_biased! {
1013                    next = rx.as_mut().peek().fuse() => {
1014                        if next.is_some() {
1015                            continue;
1016                        }
1017                    }
1018                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1019                }
1020            }
1021
1022            let url = client
1023                .http_client()
1024                .build_zed_llm_url("/predict_edits/reject", &[])
1025                .unwrap();
1026
1027            let flush_count = batched
1028                .len()
1029                // in case items have accumulated after failure
1030                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1031            let start = batched.len() - flush_count;
1032
1033            let body = RejectEditPredictionsBodyRef {
1034                rejections: &batched[start..],
1035            };
1036
1037            let result = Self::send_api_request::<()>(
1038                |builder| {
1039                    let req = builder
1040                        .uri(url.as_ref())
1041                        .body(serde_json::to_string(&body)?.into());
1042                    anyhow::Ok(req?)
1043                },
1044                client.clone(),
1045                llm_token.clone(),
1046                app_version.clone(),
1047            )
1048            .await;
1049
1050            if result.log_err().is_some() {
1051                batched.drain(start..);
1052            }
1053        }
1054    }
1055
1056    fn reject_current_prediction(
1057        &mut self,
1058        reason: EditPredictionRejectReason,
1059        project: &Entity<Project>,
1060    ) {
1061        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1062            project_state.pending_predictions.clear();
1063            if let Some(prediction) = project_state.current_prediction.take() {
1064                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1065            }
1066        };
1067    }
1068
1069    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1070        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1071            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1072                if !current_prediction.was_shown {
1073                    current_prediction.was_shown = true;
1074                    self.shown_predictions
1075                        .push_front(current_prediction.prediction.clone());
1076                    if self.shown_predictions.len() > 50 {
1077                        let completion = self.shown_predictions.pop_back().unwrap();
1078                        self.rated_predictions.remove(&completion.id);
1079                    }
1080                }
1081            }
1082        }
1083    }
1084
1085    fn reject_prediction(
1086        &mut self,
1087        prediction_id: EditPredictionId,
1088        reason: EditPredictionRejectReason,
1089        was_shown: bool,
1090    ) {
1091        match self.edit_prediction_model {
1092            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
1093            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1094        }
1095
1096        self.reject_predictions_tx
1097            .unbounded_send(EditPredictionRejection {
1098                request_id: prediction_id.to_string(),
1099                reason,
1100                was_shown,
1101            })
1102            .log_err();
1103    }
1104
1105    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1106        self.projects
1107            .get(&project.entity_id())
1108            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1109    }
1110
1111    pub fn refresh_prediction_from_buffer(
1112        &mut self,
1113        project: Entity<Project>,
1114        buffer: Entity<Buffer>,
1115        position: language::Anchor,
1116        cx: &mut Context<Self>,
1117    ) {
1118        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1119            let Some(request_task) = this
1120                .update(cx, |this, cx| {
1121                    this.request_prediction(
1122                        &project,
1123                        &buffer,
1124                        position,
1125                        PredictEditsRequestTrigger::Other,
1126                        cx,
1127                    )
1128                })
1129                .log_err()
1130            else {
1131                return Task::ready(anyhow::Ok(None));
1132            };
1133
1134            cx.spawn(async move |_cx| {
1135                request_task.await.map(|prediction_result| {
1136                    prediction_result.map(|prediction_result| {
1137                        (
1138                            prediction_result,
1139                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1140                        )
1141                    })
1142                })
1143            })
1144        })
1145    }
1146
    /// Queues a prediction request at a diagnostic location chosen by
    /// `next_diagnostic_location`, starting from the project's active buffer and cursor.
    ///
    /// Does nothing if the project is unknown or already has a current prediction.
    /// Refreshes are throttled per project entity. If a current prediction appears while
    /// this request is in flight, the new result is turned into a `CurrentPreferred`
    /// rejection instead of replacing it.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Capture the active buffer, its snapshot and the cursor point; bail out when
            // there is no active buffer or predictions are disabled at the cursor.
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    // Without a known cursor position, fall back to `Point::default()`.
                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                // NOTE(review): the search range is `Default::default()` (empty), presumably
                // so that no diagnostic is excluded as "near the cursor"; confirm how
                // `overlaps` treats an empty range.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // A prediction that became current while this request was in flight
                // wins; report this one as rejected in that case.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1233
1234    fn predictions_enabled_at(
1235        snapshot: &BufferSnapshot,
1236        position: Option<language::Anchor>,
1237        cx: &App,
1238    ) -> bool {
1239        let file = snapshot.file();
1240        let all_settings = all_language_settings(file, cx);
1241        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1242            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1243        {
1244            return false;
1245        }
1246
1247        if let Some(last_position) = position {
1248            let settings = snapshot.settings_at(last_position, cx);
1249
1250            if !settings.edit_predictions_disabled_in.is_empty()
1251                && let Some(scope) = snapshot.language_scope_at(last_position)
1252                && let Some(scope_name) = scope.override_name()
1253                && settings
1254                    .edit_predictions_disabled_in
1255                    .iter()
1256                    .any(|s| s == scope_name)
1257            {
1258                return false;
1259            }
1260        }
1261
1262        true
1263    }
1264
    /// Minimum time between two prediction refreshes triggered by the same entity; a
    /// refresh queued sooner waits out the remainder (see `queue_prediction_refresh`).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Test builds disable throttling, presumably so tests don't have to wait.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1269
1270    fn queue_prediction_refresh(
1271        &mut self,
1272        project: Entity<Project>,
1273        throttle_entity: EntityId,
1274        cx: &mut Context<Self>,
1275        do_refresh: impl FnOnce(
1276            WeakEntity<Self>,
1277            &mut AsyncApp,
1278        )
1279            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1280        + 'static,
1281    ) {
1282        let project_state = self.get_or_init_project(&project, cx);
1283        let pending_prediction_id = project_state.next_pending_prediction_id;
1284        project_state.next_pending_prediction_id += 1;
1285        let last_request = project_state.last_prediction_refresh;
1286
1287        let task = cx.spawn(async move |this, cx| {
1288            if let Some((last_entity, last_timestamp)) = last_request
1289                && throttle_entity == last_entity
1290                && let Some(timeout) =
1291                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1292            {
1293                cx.background_executor().timer(timeout).await;
1294            }
1295
1296            // If this task was cancelled before the throttle timeout expired,
1297            // do not perform a request.
1298            let mut is_cancelled = true;
1299            this.update(cx, |this, cx| {
1300                let project_state = this.get_or_init_project(&project, cx);
1301                if !project_state
1302                    .cancelled_predictions
1303                    .remove(&pending_prediction_id)
1304                {
1305                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1306                    is_cancelled = false;
1307                }
1308            })
1309            .ok();
1310            if is_cancelled {
1311                return None;
1312            }
1313
1314            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1315            let new_prediction_id = new_prediction_result
1316                .as_ref()
1317                .map(|(prediction, _)| prediction.id.clone());
1318
1319            // When a prediction completes, remove it from the pending list, and cancel
1320            // any pending predictions that were enqueued before it.
1321            this.update(cx, |this, cx| {
1322                let project_state = this.get_or_init_project(&project, cx);
1323
1324                let is_cancelled = project_state
1325                    .cancelled_predictions
1326                    .remove(&pending_prediction_id);
1327
1328                let new_current_prediction = if !is_cancelled
1329                    && let Some((prediction_result, requested_by)) = new_prediction_result
1330                {
1331                    match prediction_result.prediction {
1332                        Ok(prediction) => {
1333                            let new_prediction = CurrentEditPrediction {
1334                                requested_by,
1335                                prediction,
1336                                was_shown: false,
1337                            };
1338
1339                            if let Some(current_prediction) =
1340                                project_state.current_prediction.as_ref()
1341                            {
1342                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1343                                {
1344                                    this.reject_current_prediction(
1345                                        EditPredictionRejectReason::Replaced,
1346                                        &project,
1347                                    );
1348
1349                                    Some(new_prediction)
1350                                } else {
1351                                    this.reject_prediction(
1352                                        new_prediction.prediction.id,
1353                                        EditPredictionRejectReason::CurrentPreferred,
1354                                        false,
1355                                    );
1356                                    None
1357                                }
1358                            } else {
1359                                Some(new_prediction)
1360                            }
1361                        }
1362                        Err(reject_reason) => {
1363                            this.reject_prediction(prediction_result.id, reject_reason, false);
1364                            None
1365                        }
1366                    }
1367                } else {
1368                    None
1369                };
1370
1371                let project_state = this.get_or_init_project(&project, cx);
1372
1373                if let Some(new_prediction) = new_current_prediction {
1374                    project_state.current_prediction = Some(new_prediction);
1375                }
1376
1377                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1378                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1379                    if pending_prediction.id == pending_prediction_id {
1380                        pending_predictions.remove(ix);
1381                        for pending_prediction in pending_predictions.drain(0..ix) {
1382                            project_state.cancel_pending_prediction(pending_prediction, cx)
1383                        }
1384                        break;
1385                    }
1386                }
1387                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1388                cx.notify();
1389            })
1390            .ok();
1391
1392            new_prediction_id
1393        });
1394
1395        if project_state.pending_predictions.len() <= 1 {
1396            project_state.pending_predictions.push(PendingPrediction {
1397                id: pending_prediction_id,
1398                task,
1399            });
1400        } else if project_state.pending_predictions.len() == 2 {
1401            let pending_prediction = project_state.pending_predictions.pop().unwrap();
1402            project_state.pending_predictions.push(PendingPrediction {
1403                id: pending_prediction_id,
1404                task,
1405            });
1406            project_state.cancel_pending_prediction(pending_prediction, cx);
1407        }
1408    }
1409
1410    pub fn request_prediction(
1411        &mut self,
1412        project: &Entity<Project>,
1413        active_buffer: &Entity<Buffer>,
1414        position: language::Anchor,
1415        trigger: PredictEditsRequestTrigger,
1416        cx: &mut Context<Self>,
1417    ) -> Task<Result<Option<EditPredictionResult>>> {
1418        self.request_prediction_internal(
1419            project.clone(),
1420            active_buffer.clone(),
1421            position,
1422            trigger,
1423            cx.has_flag::<Zeta2FeatureFlag>(),
1424            cx,
1425        )
1426    }
1427
1428    fn request_prediction_internal(
1429        &mut self,
1430        project: Entity<Project>,
1431        active_buffer: Entity<Buffer>,
1432        position: language::Anchor,
1433        trigger: PredictEditsRequestTrigger,
1434        allow_jump: bool,
1435        cx: &mut Context<Self>,
1436    ) -> Task<Result<Option<EditPredictionResult>>> {
1437        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1438
1439        self.get_or_init_project(&project, cx);
1440        let project_state = self.projects.get(&project.entity_id()).unwrap();
1441        let events = project_state.events(cx);
1442        let has_events = !events.is_empty();
1443        let debug_tx = project_state.debug_tx.clone();
1444
1445        let snapshot = active_buffer.read(cx).snapshot();
1446        let cursor_point = position.to_point(&snapshot);
1447        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1448        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1449        let diagnostic_search_range =
1450            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1451
1452        let related_files = if self.use_context {
1453            self.context_for_project(&project, cx)
1454        } else {
1455            Vec::new().into()
1456        };
1457
1458        let inputs = EditPredictionModelInput {
1459            project: project.clone(),
1460            buffer: active_buffer.clone(),
1461            snapshot: snapshot.clone(),
1462            position,
1463            events,
1464            related_files,
1465            recent_paths: project_state.recent_paths.clone(),
1466            trigger,
1467            diagnostic_search_range: diagnostic_search_range.clone(),
1468            debug_tx,
1469        };
1470
1471        let task = match self.edit_prediction_model {
1472            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
1473            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
1474            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1475            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1476        };
1477
1478        cx.spawn(async move |this, cx| {
1479            let prediction = task.await?;
1480
1481            if prediction.is_none() && allow_jump {
1482                let cursor_point = position.to_point(&snapshot);
1483                if has_events
1484                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1485                        active_buffer.clone(),
1486                        &snapshot,
1487                        diagnostic_search_range,
1488                        cursor_point,
1489                        &project,
1490                        cx,
1491                    )
1492                    .await?
1493                {
1494                    return this
1495                        .update(cx, |this, cx| {
1496                            this.request_prediction_internal(
1497                                project,
1498                                jump_buffer,
1499                                jump_position,
1500                                trigger,
1501                                false,
1502                                cx,
1503                            )
1504                        })?
1505                        .await;
1506                }
1507
1508                return anyhow::Ok(None);
1509            }
1510
1511            Ok(prediction)
1512        })
1513    }
1514
    /// Picks a location to jump to for a diagnostic-driven prediction.
    ///
    /// First, the active buffer is searched. Its diagnostic groups are considered, and
    /// those whose primary entry overlaps `active_buffer_diagnostic_search_range` are
    /// skipped. Of the rest, the group whose primary entry starts on the row closest to
    /// `active_buffer_cursor_point` wins.
    ///
    /// If nothing qualifies, another path from the project's diagnostic summaries is
    /// chosen instead:
    /// - The path sharing the most leading path components with the active buffer's
    ///   path is opened.
    /// - In that buffer, the diagnostic with the smallest `severity` value is used
    ///   (presumably the most severe, as in LSP ordering; confirm).
    ///
    /// Returns `Ok(None)` when no location is found.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Find the closest diagnostic to the cursor that wasn't close enough to be
        // included in the last request.
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // `None` when the active buffer has no file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // Find the buffer with errors that shares the most parent directories.
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1593
    /// Sends an `open_ai::Request` to the `/predict_edits/raw` endpoint and returns the
    /// parsed response together with the usage value from `send_api_request`.
    /// `PREDICT_EDITS_URL` overrides the endpoint when set.
    ///
    /// With the `cli-support` feature and an `eval_cache`, responses are cached. The key
    /// is `eval_cache_kind` plus an `FxHasher` hash of the URL and the pretty-printed
    /// request. A cache hit skips the network request and returns `None` for usage.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        // The hash covers the URL first, then the pretty-printed request; changing that
        // order or format would invalidate existing cache entries.
        #[cfg(feature = "cli-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Store the fresh response for later cache hits.
        #[cfg(feature = "cli-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1651
1652    fn handle_api_response<T>(
1653        this: &WeakEntity<Self>,
1654        response: Result<(T, Option<EditPredictionUsage>)>,
1655        cx: &mut gpui::AsyncApp,
1656    ) -> Result<T> {
1657        match response {
1658            Ok((data, usage)) => {
1659                if let Some(usage) = usage {
1660                    this.update(cx, |this, cx| {
1661                        this.user_store.update(cx, |user_store, cx| {
1662                            user_store.update_edit_prediction_usage(usage, cx);
1663                        });
1664                    })
1665                    .ok();
1666                }
1667                Ok(data)
1668            }
1669            Err(err) => {
1670                if err.is::<ZedUpdateRequiredError>() {
1671                    cx.update(|cx| {
1672                        this.update(cx, |this, _cx| {
1673                            this.update_required = true;
1674                        })
1675                        .ok();
1676
1677                        let error_message: SharedString = err.to_string().into();
1678                        show_app_notification(
1679                            NotificationId::unique::<ZedUpdateRequiredError>(),
1680                            cx,
1681                            move |cx| {
1682                                cx.new(|cx| {
1683                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1684                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1685                                })
1686                            },
1687                        );
1688                    })
1689                    .ok();
1690                }
1691                Err(err)
1692            }
1693        }
1694    }
1695
1696    async fn send_api_request<Res>(
1697        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1698        client: Arc<Client>,
1699        llm_token: LlmApiToken,
1700        app_version: Version,
1701    ) -> Result<(Res, Option<EditPredictionUsage>)>
1702    where
1703        Res: DeserializeOwned,
1704    {
1705        let http_client = client.http_client();
1706        let mut token = llm_token.acquire(&client).await?;
1707        let mut did_retry = false;
1708
1709        loop {
1710            let request_builder = http_client::Request::builder().method(Method::POST);
1711
1712            let request = build(
1713                request_builder
1714                    .header("Content-Type", "application/json")
1715                    .header("Authorization", format!("Bearer {}", token))
1716                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
1717            )?;
1718
1719            let mut response = http_client.send(request).await?;
1720
1721            if let Some(minimum_required_version) = response
1722                .headers()
1723                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1724                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1725            {
1726                anyhow::ensure!(
1727                    app_version >= minimum_required_version,
1728                    ZedUpdateRequiredError {
1729                        minimum_version: minimum_required_version
1730                    }
1731                );
1732            }
1733
1734            if response.status().is_success() {
1735                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1736
1737                let mut body = Vec::new();
1738                response.body_mut().read_to_end(&mut body).await?;
1739                return Ok((serde_json::from_slice(&body)?, usage));
1740            } else if !did_retry
1741                && response
1742                    .headers()
1743                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1744                    .is_some()
1745            {
1746                did_retry = true;
1747                token = llm_token.refresh(&client).await?;
1748            } else {
1749                let mut body = String::new();
1750                response.body_mut().read_to_string(&mut body).await?;
1751                anyhow::bail!(
1752                    "Request failed with status: {:?}\nBody: {}",
1753                    response.status(),
1754                    body
1755                );
1756            }
1757        }
1758    }
1759
1760    pub fn refresh_context(
1761        &mut self,
1762        project: &Entity<Project>,
1763        buffer: &Entity<language::Buffer>,
1764        cursor_position: language::Anchor,
1765        cx: &mut Context<Self>,
1766    ) {
1767        if self.use_context {
1768            self.get_or_init_project(project, cx)
1769                .context
1770                .update(cx, |store, cx| {
1771                    store.refresh(buffer.clone(), cursor_position, cx);
1772                });
1773        }
1774    }
1775
1776    #[cfg(feature = "cli-support")]
1777    pub fn set_context_for_buffer(
1778        &mut self,
1779        project: &Entity<Project>,
1780        related_files: Vec<RelatedFile>,
1781        cx: &mut Context<Self>,
1782    ) {
1783        self.get_or_init_project(project, cx)
1784            .context
1785            .update(cx, |store, _| {
1786                store.set_related_files(related_files);
1787            });
1788    }
1789
1790    fn is_file_open_source(
1791        &self,
1792        project: &Entity<Project>,
1793        file: &Arc<dyn File>,
1794        cx: &App,
1795    ) -> bool {
1796        if !file.is_local() || file.is_private() {
1797            return false;
1798        }
1799        let Some(project_state) = self.projects.get(&project.entity_id()) else {
1800            return false;
1801        };
1802        project_state
1803            .license_detection_watchers
1804            .get(&file.worktree_id(cx))
1805            .as_ref()
1806            .is_some_and(|watcher| watcher.is_project_open_source())
1807    }
1808
1809    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1810        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1811    }
1812
1813    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
1814        if !self.data_collection_choice.is_enabled() {
1815            return false;
1816        }
1817        events.iter().all(|event| {
1818            matches!(
1819                event.as_ref(),
1820                zeta_prompt::Event::BufferChange {
1821                    in_open_source_repo: true,
1822                    ..
1823                }
1824            )
1825        })
1826    }
1827
1828    fn load_data_collection_choice() -> DataCollectionChoice {
1829        let choice = KEY_VALUE_STORE
1830            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1831            .log_err()
1832            .flatten();
1833
1834        match choice.as_deref() {
1835            Some("true") => DataCollectionChoice::Enabled,
1836            Some("false") => DataCollectionChoice::Disabled,
1837            Some(_) => {
1838                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1839                DataCollectionChoice::NotAnswered
1840            }
1841            None => DataCollectionChoice::NotAnswered,
1842        }
1843    }
1844
1845    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1846        self.data_collection_choice = self.data_collection_choice.toggle();
1847        let new_choice = self.data_collection_choice;
1848        db::write_and_log(cx, move || {
1849            KEY_VALUE_STORE.write_kvp(
1850                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1851                new_choice.is_enabled().to_string(),
1852            )
1853        });
1854    }
1855
    /// Iterates over the predictions stored in `shown_predictions`, in either
    /// direction.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }

    /// Number of predictions stored in `shown_predictions`.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }

    /// Whether `id` is in `rated_predictions`. `rate_prediction` adds to this set.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
1867
1868    pub fn rate_prediction(
1869        &mut self,
1870        prediction: &EditPrediction,
1871        rating: EditPredictionRating,
1872        feedback: String,
1873        cx: &mut Context<Self>,
1874    ) {
1875        self.rated_predictions.insert(prediction.id.clone());
1876        telemetry::event!(
1877            "Edit Prediction Rated",
1878            rating,
1879            inputs = prediction.inputs,
1880            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
1881            feedback
1882        );
1883        self.client.telemetry().flush_events().detach();
1884        cx.notify();
1885    }
1886
1887    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
1888        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
1889            && all_language_settings(None, cx).edit_predictions.use_context;
1890    }
1891}
1892
/// Error returned by the edit prediction API client when the server's
/// minimum-required-version response header names a version newer than the
/// running app.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum app version reported by the server.
    minimum_version: Version,
}
1900
/// Key identifying an [`EvalCache`] entry: the entry kind plus a 64-bit hash.
///
/// `send_raw_llm_request` computes the hash from the request URL and the
/// pretty-printed request JSON.
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
1903
1904#[cfg(feature = "cli-support")]
/// Category of an [`EvalCache`] entry. It is the first component of an
/// [`EvalCacheKey`].
///
/// `Eq` and `Hash` are derived so that `EvalCacheKey`, a tuple of this kind and
/// a `u64` hash, can be used directly as a `HashMap`/`HashSet` key by cache
/// implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
1911
1912#[cfg(feature = "cli-support")]
1913impl std::fmt::Display for EvalCacheEntryKind {
1914    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1915        match self {
1916            EvalCacheEntryKind::Search => write!(f, "search"),
1917            EvalCacheEntryKind::Context => write!(f, "context"),
1918            EvalCacheEntryKind::Prediction => write!(f, "prediction"),
1919        }
1920    }
1921}
1922
/// Cache of serialized responses, available with the `cli-support` feature
/// (presumably for evaluation runs — confirm with the CLI callers).
///
/// For example, `send_raw_llm_request` stores the response JSON as `value` and
/// the pretty-printed request JSON as `input`.
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached value stored under `key`, if any.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` under `key`, along with the `input` that produced it.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
1928
/// The user's answer to whether edit prediction data may be collected.
///
/// `PartialEq`/`Eq` are derived so choices can be compared with `==`. This
/// type is a plain `Copy` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    /// No choice has been made yet; collection is treated as disabled.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for [`DataCollectionChoice::Enabled`].
    pub fn is_enabled(self) -> bool {
        match self {
            Self::Enabled => true,
            Self::NotAnswered | Self::Disabled => false,
        }
    }

    /// Returns `true` once the user has chosen either way.
    pub fn is_answered(self) -> bool {
        match self {
            Self::Enabled | Self::Disabled => true,
            Self::NotAnswered => false,
        }
    }

    /// Returns the opposite choice. An unanswered choice toggles to `Enabled`.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1969
1970struct ZedPredictUpsell;
1971
1972impl Dismissable for ZedPredictUpsell {
1973    const KEY: &'static str = "dismissed-edit-predict-upsell";
1974
1975    fn dismissed() -> bool {
1976        // To make this backwards compatible with older versions of Zed, we
1977        // check if the user has seen the previous Edit Prediction Onboarding
1978        // before, by checking the data collection choice which was written to
1979        // the database once the user clicked on "Accept and Enable"
1980        if KEY_VALUE_STORE
1981            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1982            .log_err()
1983            .is_some_and(|s| s.is_some())
1984        {
1985            return true;
1986        }
1987
1988        KEY_VALUE_STORE
1989            .read_kvp(Self::KEY)
1990            .log_err()
1991            .is_some_and(|s| s.is_some())
1992    }
1993}
1994
1995pub fn should_show_upsell_modal() -> bool {
1996    !ZedPredictUpsell::dismissed()
1997}
1998
/// Registers edit prediction workspace actions on every newly created
/// [`Workspace`]:
///
/// - `zed_actions::OpenZedPredictOnboarding` calls `ZedPredictModal::toggle`
///   with the workspace's user store and client.
/// - `ResetOnboarding` sets `edit_prediction_provider` to
///   `EditPredictionProvider::None` in the settings file.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Toggle the Zed Predict onboarding modal.
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding writes `edit_prediction_provider = None` into the
        // `features` section of the settings file, creating it if absent.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}