edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
   7    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
   8    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
   9    ZED_VERSION_HEADER_NAME,
  10};
  11use collections::{HashMap, HashSet};
  12use db::kvp::{Dismissable, KEY_VALUE_STORE};
  13use edit_prediction_context::EditPredictionExcerptOptions;
  14use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  16use futures::{
  17    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  18    channel::mpsc::{self, UnboundedReceiver},
  19    select_biased,
  20};
  21use gpui::BackgroundExecutor;
  22use gpui::{
  23    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  24    http_client::{self, AsyncBody, Method},
  25    prelude::*,
  26};
  27use language::language_settings::all_language_settings;
  28use language::{Anchor, Buffer, File, Point, ToPoint};
  29use language::{BufferSnapshot, OffsetRangeExt};
  30use language_model::{LlmApiToken, RefreshLlmTokenListener};
  31use project::{Project, ProjectPath, WorktreeId};
  32use release_channel::AppVersion;
  33use semver::Version;
  34use serde::de::DeserializeOwned;
  35use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
  36use std::collections::{VecDeque, hash_map};
  37use workspace::Workspace;
  38
  39use std::ops::Range;
  40use std::path::Path;
  41use std::rc::Rc;
  42use std::str::FromStr as _;
  43use std::sync::{Arc, LazyLock};
  44use std::time::{Duration, Instant};
  45use std::{env, mem};
  46use thiserror::Error;
  47use util::{RangeExt as _, ResultExt as _};
  48use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  49
  50mod cursor_excerpt;
  51mod license_detection;
  52pub mod mercury;
  53mod onboarding_modal;
  54pub mod open_ai_response;
  55mod prediction;
  56pub mod sweep_ai;
  57
  58#[cfg(any(test, feature = "test-support", feature = "eval-support"))]
  59pub mod udiff;
  60
  61mod zed_edit_prediction_delegate;
  62pub mod zeta1;
  63pub mod zeta2;
  64
  65#[cfg(test)]
  66mod edit_prediction_tests;
  67
  68use crate::license_detection::LicenseDetectionWatcher;
  69use crate::mercury::Mercury;
  70use crate::onboarding_modal::ZedPredictModal;
  71pub use crate::prediction::EditPrediction;
  72pub use crate::prediction::EditPredictionId;
  73use crate::prediction::EditPredictionResult;
  74pub use crate::sweep_ai::SweepAi;
  75pub use telemetry_events::EditPredictionRating;
  76pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  77
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  87
  88/// Maximum number of events to track.
  89const EVENT_COUNT_MAX: usize = 6;
  90const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
  91const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
  92const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
  93
  94pub struct SweepFeatureFlag;
  95
  96impl FeatureFlag for SweepFeatureFlag {
  97    const NAME: &str = "sweep-ai";
  98}
  99
 100pub struct MercuryFeatureFlag;
 101
 102impl FeatureFlag for MercuryFeatureFlag {
 103    const NAME: &str = "mercury";
 104}
 105
 106pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
 107    context: EditPredictionExcerptOptions {
 108        max_bytes: 512,
 109        min_bytes: 128,
 110        target_before_cursor_over_total_bytes: 0.5,
 111    },
 112    prompt_format: PromptFormat::DEFAULT,
 113};
 114
 115static USE_OLLAMA: LazyLock<bool> =
 116    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 117
 118static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 119    match env::var("ZED_ZETA2_MODEL").as_deref() {
 120        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 121        Ok(model) => model,
 122        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 123        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 124    }
 125    .to_string()
 126});
 127static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 128    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 129        if *USE_OLLAMA {
 130            Some("http://localhost:11434/v1/chat/completions".into())
 131        } else {
 132            None
 133        }
 134    })
 135});
 136
 137pub struct Zeta2FeatureFlag;
 138
 139impl FeatureFlag for Zeta2FeatureFlag {
 140    const NAME: &'static str = "zeta2";
 141
 142    fn enabled_for_staff() -> bool {
 143        true
 144    }
 145}
 146
 147#[derive(Clone)]
 148struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
 149
 150impl Global for EditPredictionStoreGlobal {}
 151
/// Central state for edit predictions. It holds per-project edit history and
/// prediction bookkeeping, the selected prediction backend, and the client,
/// token, and channels used to reach the prediction service.
pub struct EditPredictionStore {
    /// Client shared with the rejection-reporting task and LLM token refreshes.
    client: Arc<Client>,
    /// Queried for edit prediction usage (see `usage`).
    user_store: Entity<UserStore>,
    /// LLM API token, refreshed whenever the global `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive for as long as the store lives.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// Set via `set_use_context`; starts out `false`.
    use_context: bool,
    /// Excerpt and prompt-format options; starts out as `DEFAULT_OPTIONS`.
    options: ZetaOptions,
    /// Starts out `false`. NOTE(review): presumably set when the server reports
    /// that a newer app version is required — confirm.
    update_required: bool,
    /// Evaluation cache installed via `with_eval_cache`.
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    /// Backend used for predictions; `new` selects `EditPredictionModel::Zeta2`.
    edit_prediction_model: EditPredictionModel,
    /// Sweep AI backend state, including its (possibly still loading) API token.
    pub sweep_ai: SweepAi,
    /// Mercury backend state, including its (possibly still loading) API token.
    pub mercury: Mercury,
    /// Loaded by `load_data_collection_choice` when the store is created.
    data_collection_choice: DataCollectionChoice,
    /// Sends rejected predictions to the background task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    /// NOTE(review): presumably predictions that were displayed to the user — confirm.
    shown_predictions: VecDeque<EditPrediction>,
    /// NOTE(review): presumably ids of predictions the user has rated — confirm.
    rated_predictions: HashSet<EditPredictionId>,
}
 171
 172#[derive(Copy, Clone, Default, PartialEq, Eq)]
 173pub enum EditPredictionModel {
 174    #[default]
 175    Zeta1,
 176    Zeta2,
 177    Sweep,
 178    Mercury,
 179}
 180
/// Inputs collected for a single edit prediction request.
///
/// NOTE(review): built outside this view; the field notes below are inferred
/// from names and types — confirm against the prediction code.
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    /// Buffer the prediction is requested for.
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` taken for this request.
    snapshot: BufferSnapshot,
    /// Presumably the cursor position within `buffer`.
    position: Anchor,
    /// Edit history (same element type as `ProjectState::events`).
    events: Vec<Arc<zeta_prompt::Event>>,
    /// Related excerpts from other files.
    related_files: Arc<[RelatedFile]>,
    /// Recently active paths (same type as `ProjectState::recent_paths`).
    recent_paths: VecDeque<ProjectPath>,
    /// What triggered this request.
    trigger: PredictEditsRequestTrigger,
    /// Presumably the row range searched for diagnostics.
    diagnostic_search_range: Range<Point>,
    /// Optional channel for debug events.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
}
 193
/// Tunable options for Zeta prompts (see `DEFAULT_OPTIONS` for the defaults).
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Size limits for the excerpt taken around the cursor.
    pub context: EditPredictionExcerptOptions,
    /// Format of the prompt sent to the model.
    pub prompt_format: predict_edits_v3::PromptFormat,
}
 199
/// Diagnostic events delivered to the receiver returned by
/// `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Emitted when a project's related-excerpt store starts refreshing.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Currently always empty when emitted from `handle_excerpt_store_event`.
    pub search_prompt: String,
}

/// Emitted when a project's related-excerpt store finishes refreshing.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Label/value pairs for display, e.g. cache hit ratio and LSP latencies.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when a prediction request starts (emitted outside this view).
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// The prompt sent to the model, when available.
    pub prompt: Option<String>,
}

/// Emitted when a prediction request finishes (emitted outside this view).
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when available.
    pub model_output: Option<String>,
}

/// Debug information attached to v3 prediction requests.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 237
/// Edit-prediction state tracked for one project.
struct ProjectState {
    /// Finalized edit events, oldest first; capped relative to `EVENT_COUNT_MAX`.
    events: VecDeque<Arc<zeta_prompt::Event>>,
    /// In-progress event that nearby follow-up edits may still be merged into.
    last_event: Option<LastEvent>,
    /// Recently active paths, most recent first, without duplicates.
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers whose edits are being recorded, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// NOTE(review): presumably the id to assign to the next `PendingPrediction` — confirm.
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two, per the `ArrayVec` capacity).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Set by `EditPredictionStore::debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    /// NOTE(review): presumably the buffer and time of the last prediction refresh — confirm.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt store supplying cross-file context for this project.
    context: Entity<RelatedExcerptStore>,
    /// One license watcher per worktree; removed when the worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Subscription to the project's events.
    _subscription: gpui::Subscription,
}
 253
 254impl ProjectState {
 255    pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
 256        self.events
 257            .iter()
 258            .cloned()
 259            .chain(
 260                self.last_event
 261                    .as_ref()
 262                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 263            )
 264            .collect()
 265    }
 266
 267    fn cancel_pending_prediction(
 268        &mut self,
 269        pending_prediction: PendingPrediction,
 270        cx: &mut Context<EditPredictionStore>,
 271    ) {
 272        self.cancelled_predictions.insert(pending_prediction.id);
 273
 274        cx.spawn(async move |this, cx| {
 275            let Some(prediction_id) = pending_prediction.task.await else {
 276                return;
 277            };
 278
 279            this.update(cx, |this, _cx| {
 280                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 281            })
 282            .ok();
 283        })
 284        .detach()
 285    }
 286}
 287
/// The prediction currently offered for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// NOTE(review): presumably whether the prediction has been displayed — confirm.
    pub was_shown: bool,
}
 294
 295impl CurrentEditPrediction {
 296    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 297        let Some(new_edits) = self
 298            .prediction
 299            .interpolate(&self.prediction.buffer.read(cx))
 300        else {
 301            return false;
 302        };
 303
 304        if self.prediction.buffer != old_prediction.prediction.buffer {
 305            return true;
 306        }
 307
 308        let Some(old_edits) = old_prediction
 309            .prediction
 310            .interpolate(&old_prediction.prediction.buffer.read(cx))
 311        else {
 312            return true;
 313        };
 314
 315        let requested_by_buffer_id = self.requested_by.buffer_id();
 316
 317        // This reduces the occurrence of UI thrash from replacing edits
 318        //
 319        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 320        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 321            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 322            && old_edits.len() == 1
 323            && new_edits.len() == 1
 324        {
 325            let (old_range, old_text) = &old_edits[0];
 326            let (new_range, new_text) = &new_edits[0];
 327            new_range == old_range && new_text.starts_with(old_text.as_ref())
 328        } else {
 329            true
 330        }
 331    }
 332}
 333
 334#[derive(Debug, Clone)]
 335enum PredictionRequestedBy {
 336    DiagnosticsUpdate,
 337    Buffer(EntityId),
 338}
 339
 340impl PredictionRequestedBy {
 341    pub fn buffer_id(&self) -> Option<EntityId> {
 342        match self {
 343            PredictionRequestedBy::DiagnosticsUpdate => None,
 344            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 345        }
 346    }
 347}
 348
/// A prediction request that has not completed yet.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    /// Resolves to the produced prediction's id, or `None` if there is none to
    /// reject when cancelled.
    task: Task<Option<EditPredictionId>>,
}
 354
 355/// A prediction from the perspective of a buffer.
 356#[derive(Debug)]
 357enum BufferEditPrediction<'a> {
 358    Local { prediction: &'a EditPrediction },
 359    Jump { prediction: &'a EditPrediction },
 360}
 361
 362#[cfg(test)]
 363impl std::ops::Deref for BufferEditPrediction<'_> {
 364    type Target = EditPrediction;
 365
 366    fn deref(&self) -> &Self::Target {
 367        match self {
 368            BufferEditPrediction::Local { prediction } => prediction,
 369            BufferEditPrediction::Jump { prediction } => prediction,
 370        }
 371    }
 372}
 373
/// Tracking entry for a buffer whose edits are being recorded.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; new edits are diffed against it.
    snapshot: BufferSnapshot,
    /// Edit subscription and release observer for the buffer.
    _subscriptions: [gpui::Subscription; 2],
}

/// An edit event that is still open and may absorb further nearby edits.
struct LastEvent {
    /// Buffer contents before the first edit in this event.
    old_snapshot: BufferSnapshot,
    /// Buffer contents after the latest edit in this event.
    new_snapshot: BufferSnapshot,
    /// End of the latest edit, used to decide whether the next edit is close
    /// enough to coalesce; `None` if no anchored edit was found.
    end_edit_anchor: Option<Anchor>,
}
 384
 385impl LastEvent {
 386    pub fn finalize(
 387        &self,
 388        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 389        cx: &App,
 390    ) -> Option<Arc<zeta_prompt::Event>> {
 391        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
 392        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
 393
 394        let file = self.new_snapshot.file();
 395        let old_file = self.old_snapshot.file();
 396
 397        let in_open_source_repo = [file, old_file].iter().all(|file| {
 398            file.is_some_and(|file| {
 399                license_detection_watchers
 400                    .get(&file.worktree_id(cx))
 401                    .is_some_and(|watcher| watcher.is_project_open_source())
 402            })
 403        });
 404
 405        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
 406
 407        if path == old_path && diff.is_empty() {
 408            None
 409        } else {
 410            Some(Arc::new(zeta_prompt::Event::BufferChange {
 411                old_path,
 412                path,
 413                diff,
 414                in_open_source_repo,
 415                // TODO: Actually detect if this edit was predicted or not
 416                predicted: false,
 417            }))
 418        }
 419    }
 420}
 421
 422fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
 423    if let Some(file) = snapshot.file() {
 424        file.full_path(cx).into()
 425    } else {
 426        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 427    }
 428}
 429
 430impl EditPredictionStore {
 431    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 432        cx.try_global::<EditPredictionStoreGlobal>()
 433            .map(|global| global.0.clone())
 434    }
 435
 436    pub fn global(
 437        client: &Arc<Client>,
 438        user_store: &Entity<UserStore>,
 439        cx: &mut App,
 440    ) -> Entity<Self> {
 441        cx.try_global::<EditPredictionStoreGlobal>()
 442            .map(|global| global.0.clone())
 443            .unwrap_or_else(|| {
 444                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 445                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 446                ep_store
 447            })
 448    }
 449
    /// Creates the store. This spawns the detached background task that
    /// handles rejected predictions, subscribes to LLM token refreshes, and
    /// (re)configures context retrieval now, once feature flags are ready, and
    /// whenever settings change.
    ///
    /// The store starts with `DEFAULT_OPTIONS`, `use_context == false`, and
    /// `EditPredictionModel::Zeta2` selected.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections flow through this channel to a detached background task.
        // NOTE(review): `handle_rejected_predictions` is outside this view;
        // presumably it batches/debounces the rejections — confirm.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            // Refresh our token whenever the global listener signals; failures
            // are logged rather than propagated.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        };

        this.configure_context_retrieval(cx);
        // Re-run once feature flags are ready, since the configuration may
        // depend on them. A weak handle avoids keeping the store alive.
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        // Settings changes can also affect context retrieval.
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
 521
 522    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 523        self.edit_prediction_model = model;
 524    }
 525
 526    pub fn has_sweep_api_token(&self) -> bool {
 527        self.sweep_ai
 528            .api_token
 529            .clone()
 530            .now_or_never()
 531            .flatten()
 532            .is_some()
 533    }
 534
 535    pub fn has_mercury_api_token(&self) -> bool {
 536        self.mercury
 537            .api_token
 538            .clone()
 539            .now_or_never()
 540            .flatten()
 541            .is_some()
 542    }
 543
 544    #[cfg(feature = "eval-support")]
 545    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
 546        self.eval_cache = Some(cache);
 547    }
 548
 549    pub fn options(&self) -> &ZetaOptions {
 550        &self.options
 551    }
 552
 553    pub fn set_options(&mut self, options: ZetaOptions) {
 554        self.options = options;
 555    }
 556
 557    pub fn set_use_context(&mut self, use_context: bool) {
 558        self.use_context = use_context;
 559    }
 560
 561    pub fn clear_history(&mut self) {
 562        for project_state in self.projects.values_mut() {
 563            project_state.events.clear();
 564        }
 565    }
 566
 567    pub fn edit_history_for_project(
 568        &self,
 569        project: &Entity<Project>,
 570    ) -> Vec<Arc<zeta_prompt::Event>> {
 571        self.projects
 572            .get(&project.entity_id())
 573            .map(|project_state| project_state.events.iter().cloned().collect())
 574            .unwrap_or_default()
 575    }
 576
 577    pub fn context_for_project<'a>(
 578        &'a self,
 579        project: &Entity<Project>,
 580        cx: &'a App,
 581    ) -> Arc<[RelatedFile]> {
 582        self.projects
 583            .get(&project.entity_id())
 584            .map(|project| project.context.read(cx).related_files())
 585            .unwrap_or_else(|| vec![].into())
 586    }
 587
 588    pub fn context_for_project_with_buffers<'a>(
 589        &'a self,
 590        project: &Entity<Project>,
 591        cx: &'a App,
 592    ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
 593        self.projects
 594            .get(&project.entity_id())
 595            .map(|project| project.context.read(cx).related_files_with_buffers())
 596    }
 597
 598    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 599        if self.edit_prediction_model == EditPredictionModel::Zeta2 {
 600            self.user_store.read(cx).edit_prediction_usage()
 601        } else {
 602            None
 603        }
 604    }
 605
 606    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 607        self.get_or_init_project(project, cx);
 608    }
 609
 610    pub fn register_buffer(
 611        &mut self,
 612        buffer: &Entity<Buffer>,
 613        project: &Entity<Project>,
 614        cx: &mut Context<Self>,
 615    ) {
 616        let project_state = self.get_or_init_project(project, cx);
 617        Self::register_buffer_impl(project_state, buffer, project, cx);
 618    }
 619
 620    fn get_or_init_project(
 621        &mut self,
 622        project: &Entity<Project>,
 623        cx: &mut Context<Self>,
 624    ) -> &mut ProjectState {
 625        let entity_id = project.entity_id();
 626        self.projects
 627            .entry(entity_id)
 628            .or_insert_with(|| ProjectState {
 629                context: {
 630                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 631                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 632                        this.handle_excerpt_store_event(entity_id, event);
 633                    })
 634                    .detach();
 635                    related_excerpt_store
 636                },
 637                events: VecDeque::new(),
 638                last_event: None,
 639                recent_paths: VecDeque::new(),
 640                debug_tx: None,
 641                registered_buffers: HashMap::default(),
 642                current_prediction: None,
 643                cancelled_predictions: HashSet::default(),
 644                pending_predictions: ArrayVec::new(),
 645                next_pending_prediction_id: 0,
 646                last_prediction_refresh: None,
 647                license_detection_watchers: HashMap::default(),
 648                _subscription: cx.subscribe(&project, Self::handle_project_event),
 649            })
 650    }
 651
 652    pub fn remove_project(&mut self, project: &Entity<Project>) {
 653        self.projects.remove(&project.entity_id());
 654    }
 655
 656    fn handle_excerpt_store_event(
 657        &mut self,
 658        project_entity_id: EntityId,
 659        event: &RelatedExcerptStoreEvent,
 660    ) {
 661        if let Some(project_state) = self.projects.get(&project_entity_id) {
 662            if let Some(debug_tx) = project_state.debug_tx.clone() {
 663                match event {
 664                    RelatedExcerptStoreEvent::StartedRefresh => {
 665                        debug_tx
 666                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 667                                ContextRetrievalStartedDebugEvent {
 668                                    project_entity_id: project_entity_id,
 669                                    timestamp: Instant::now(),
 670                                    search_prompt: String::new(),
 671                                },
 672                            ))
 673                            .ok();
 674                    }
 675                    RelatedExcerptStoreEvent::FinishedRefresh {
 676                        cache_hit_count,
 677                        cache_miss_count,
 678                        mean_definition_latency,
 679                        max_definition_latency,
 680                    } => {
 681                        debug_tx
 682                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 683                                ContextRetrievalFinishedDebugEvent {
 684                                    project_entity_id: project_entity_id,
 685                                    timestamp: Instant::now(),
 686                                    metadata: vec![
 687                                        (
 688                                            "Cache Hits",
 689                                            format!(
 690                                                "{}/{}",
 691                                                cache_hit_count,
 692                                                cache_hit_count + cache_miss_count
 693                                            )
 694                                            .into(),
 695                                        ),
 696                                        (
 697                                            "Max LSP Time",
 698                                            format!("{} ms", max_definition_latency.as_millis())
 699                                                .into(),
 700                                        ),
 701                                        (
 702                                            "Mean LSP Time",
 703                                            format!("{} ms", mean_definition_latency.as_millis())
 704                                                .into(),
 705                                        ),
 706                                    ],
 707                                },
 708                            ))
 709                            .ok();
 710                    }
 711                }
 712            }
 713        }
 714    }
 715
 716    pub fn debug_info(
 717        &mut self,
 718        project: &Entity<Project>,
 719        cx: &mut Context<Self>,
 720    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 721        let project_state = self.get_or_init_project(project, cx);
 722        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 723        project_state.debug_tx = Some(debug_watch_tx);
 724        debug_watch_rx
 725    }
 726
 727    fn handle_project_event(
 728        &mut self,
 729        project: Entity<Project>,
 730        event: &project::Event,
 731        cx: &mut Context<Self>,
 732    ) {
 733        // TODO [zeta2] init with recent paths
 734        match event {
 735            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 736                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 737                    return;
 738                };
 739                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 740                if let Some(path) = path {
 741                    if let Some(ix) = project_state
 742                        .recent_paths
 743                        .iter()
 744                        .position(|probe| probe == &path)
 745                    {
 746                        project_state.recent_paths.remove(ix);
 747                    }
 748                    project_state.recent_paths.push_front(path);
 749                }
 750            }
 751            project::Event::DiagnosticsUpdated { .. } => {
 752                if cx.has_flag::<Zeta2FeatureFlag>() {
 753                    self.refresh_prediction_from_diagnostics(project, cx);
 754                }
 755            }
 756            _ => (),
 757        }
 758    }
 759
    /// Registers `buffer` in `project_state` (if it is not already) and returns
    /// its tracking entry.
    ///
    /// Side effects:
    /// - ensures a `LicenseDetectionWatcher` exists for the buffer's worktree,
    ///   and removes it again when that worktree is released;
    /// - for a newly registered buffer, records its current snapshot;
    /// - subscribes to the buffer's `Edited` events, which are forwarded to
    ///   `report_changes_for_buffer`;
    /// - observes the buffer's release, which removes the entry.
    ///
    /// Already-registered buffers are returned as-is.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Buffers without a file have no worktree, so no license watcher.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher when its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Hold the project weakly so this subscription does not
                        // keep it alive; edits after it is gone are ignored.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Stop tracking the buffer once it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 821
    /// Records the edits made to `buffer` since it was last observed as part of
    /// `project`'s edit history.
    ///
    /// The buffer is registered with the project state if needed. If its
    /// version has not changed since the recorded snapshot, nothing happens.
    /// Otherwise the change is either coalesced into the pending `last_event`,
    /// or the pending event is finalized into `events` and a new pending event
    /// is started for this change.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Same version as the snapshot we already recorded: no new edits.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        // Store the new snapshot on the registered buffer and keep the previous
        // one to diff against.
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // End anchor of the last edit between the two snapshots, if any.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // The pending event ends exactly at the snapshot this change starts from.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // Coalesce only when both edits ended within CHANGE_GROUPING_LINE_SPAN
            // rows of each other, measured in the new snapshot.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // Extend the pending event to cover this change.
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Evict the oldest event before the pending one is finalized below.
        // The `+ 1` presumably reserves room for the new pending event.
        // NOTE(review): this pops even when there is no pending event to push;
        // confirm that is intended.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        // Start a new pending event for this change.
        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
 884
 885    fn current_prediction_for_buffer(
 886        &self,
 887        buffer: &Entity<Buffer>,
 888        project: &Entity<Project>,
 889        cx: &App,
 890    ) -> Option<BufferEditPrediction<'_>> {
 891        let project_state = self.projects.get(&project.entity_id())?;
 892
 893        let CurrentEditPrediction {
 894            requested_by,
 895            prediction,
 896            ..
 897        } = project_state.current_prediction.as_ref()?;
 898
 899        if prediction.targets_buffer(buffer.read(cx)) {
 900            Some(BufferEditPrediction::Local { prediction })
 901        } else {
 902            let show_jump = match requested_by {
 903                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
 904                    requested_by_buffer_id == &buffer.entity_id()
 905                }
 906                PredictionRequestedBy::DiagnosticsUpdate => true,
 907            };
 908
 909            if show_jump {
 910                Some(BufferEditPrediction::Jump { prediction })
 911            } else {
 912                None
 913            }
 914        }
 915    }
 916
 917    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 918        match self.edit_prediction_model {
 919            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
 920            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
 921        }
 922
 923        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 924            return;
 925        };
 926
 927        let Some(prediction) = project_state.current_prediction.take() else {
 928            return;
 929        };
 930        let request_id = prediction.prediction.id.to_string();
 931        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
 932            project_state.cancel_pending_prediction(pending_prediction, cx);
 933        }
 934
 935        let client = self.client.clone();
 936        let llm_token = self.llm_token.clone();
 937        let app_version = AppVersion::global(cx);
 938        cx.spawn(async move |this, cx| {
 939            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
 940                http_client::Url::parse(&predict_edits_url)?
 941            } else {
 942                client
 943                    .http_client()
 944                    .build_zed_llm_url("/predict_edits/accept", &[])?
 945            };
 946
 947            let response = cx
 948                .background_spawn(Self::send_api_request::<()>(
 949                    move |builder| {
 950                        let req = builder.uri(url.as_ref()).body(
 951                            serde_json::to_string(&AcceptEditPredictionBody {
 952                                request_id: request_id.clone(),
 953                            })?
 954                            .into(),
 955                        );
 956                        Ok(req?)
 957                    },
 958                    client,
 959                    llm_token,
 960                    app_version,
 961                ))
 962                .await;
 963
 964            Self::handle_api_response(&this, response, cx)?;
 965            anyhow::Ok(())
 966        })
 967        .detach_and_log_err(cx);
 968    }
 969
 970    async fn handle_rejected_predictions(
 971        rx: UnboundedReceiver<EditPredictionRejection>,
 972        client: Arc<Client>,
 973        llm_token: LlmApiToken,
 974        app_version: Version,
 975        background_executor: BackgroundExecutor,
 976    ) {
 977        let mut rx = std::pin::pin!(rx.peekable());
 978        let mut batched = Vec::new();
 979
 980        while let Some(rejection) = rx.next().await {
 981            batched.push(rejection);
 982
 983            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
 984                select_biased! {
 985                    next = rx.as_mut().peek().fuse() => {
 986                        if next.is_some() {
 987                            continue;
 988                        }
 989                    }
 990                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
 991                }
 992            }
 993
 994            let url = client
 995                .http_client()
 996                .build_zed_llm_url("/predict_edits/reject", &[])
 997                .unwrap();
 998
 999            let flush_count = batched
1000                .len()
1001                // in case items have accumulated after failure
1002                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1003            let start = batched.len() - flush_count;
1004
1005            let body = RejectEditPredictionsBodyRef {
1006                rejections: &batched[start..],
1007            };
1008
1009            let result = Self::send_api_request::<()>(
1010                |builder| {
1011                    let req = builder
1012                        .uri(url.as_ref())
1013                        .body(serde_json::to_string(&body)?.into());
1014                    anyhow::Ok(req?)
1015                },
1016                client.clone(),
1017                llm_token.clone(),
1018                app_version.clone(),
1019            )
1020            .await;
1021
1022            if result.log_err().is_some() {
1023                batched.drain(start..);
1024            }
1025        }
1026    }
1027
1028    fn reject_current_prediction(
1029        &mut self,
1030        reason: EditPredictionRejectReason,
1031        project: &Entity<Project>,
1032    ) {
1033        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1034            project_state.pending_predictions.clear();
1035            if let Some(prediction) = project_state.current_prediction.take() {
1036                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1037            }
1038        };
1039    }
1040
1041    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1042        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1043            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1044                if !current_prediction.was_shown {
1045                    current_prediction.was_shown = true;
1046                    self.shown_predictions
1047                        .push_front(current_prediction.prediction.clone());
1048                    if self.shown_predictions.len() > 50 {
1049                        let completion = self.shown_predictions.pop_back().unwrap();
1050                        self.rated_predictions.remove(&completion.id);
1051                    }
1052                }
1053            }
1054        }
1055    }
1056
1057    fn reject_prediction(
1058        &mut self,
1059        prediction_id: EditPredictionId,
1060        reason: EditPredictionRejectReason,
1061        was_shown: bool,
1062    ) {
1063        match self.edit_prediction_model {
1064            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
1065            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1066        }
1067
1068        self.reject_predictions_tx
1069            .unbounded_send(EditPredictionRejection {
1070                request_id: prediction_id.to_string(),
1071                reason,
1072                was_shown,
1073            })
1074            .log_err();
1075    }
1076
1077    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1078        self.projects
1079            .get(&project.entity_id())
1080            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1081    }
1082
1083    pub fn refresh_prediction_from_buffer(
1084        &mut self,
1085        project: Entity<Project>,
1086        buffer: Entity<Buffer>,
1087        position: language::Anchor,
1088        cx: &mut Context<Self>,
1089    ) {
1090        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1091            let Some(request_task) = this
1092                .update(cx, |this, cx| {
1093                    this.request_prediction(
1094                        &project,
1095                        &buffer,
1096                        position,
1097                        PredictEditsRequestTrigger::Other,
1098                        cx,
1099                    )
1100                })
1101                .log_err()
1102            else {
1103                return Task::ready(anyhow::Ok(None));
1104            };
1105
1106            cx.spawn(async move |_cx| {
1107                request_task.await.map(|prediction_result| {
1108                    prediction_result.map(|prediction_result| {
1109                        (
1110                            prediction_result,
1111                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1112                        )
1113                    })
1114                })
1115            })
1116        })
1117    }
1118
1119    pub fn refresh_prediction_from_diagnostics(
1120        &mut self,
1121        project: Entity<Project>,
1122        cx: &mut Context<Self>,
1123    ) {
1124        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1125            return;
1126        };
1127
1128        // Prefer predictions from buffer
1129        if project_state.current_prediction.is_some() {
1130            return;
1131        };
1132
1133        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1134            let Some(open_buffer_task) = project
1135                .update(cx, |project, cx| {
1136                    project
1137                        .active_entry()
1138                        .and_then(|entry| project.path_for_entry(entry, cx))
1139                        .map(|path| project.open_buffer(path, cx))
1140                })
1141                .log_err()
1142                .flatten()
1143            else {
1144                return Task::ready(anyhow::Ok(None));
1145            };
1146
1147            cx.spawn(async move |cx| {
1148                let active_buffer = open_buffer_task.await?;
1149                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
1150
1151                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1152                    active_buffer,
1153                    &snapshot,
1154                    Default::default(),
1155                    Default::default(),
1156                    &project,
1157                    cx,
1158                )
1159                .await?
1160                else {
1161                    return anyhow::Ok(None);
1162                };
1163
1164                let Some(prediction_result) = this
1165                    .update(cx, |this, cx| {
1166                        this.request_prediction(
1167                            &project,
1168                            &jump_buffer,
1169                            jump_position,
1170                            PredictEditsRequestTrigger::Diagnostics,
1171                            cx,
1172                        )
1173                    })?
1174                    .await?
1175                else {
1176                    return anyhow::Ok(None);
1177                };
1178
1179                this.update(cx, |this, cx| {
1180                    Some((
1181                        if this
1182                            .get_or_init_project(&project, cx)
1183                            .current_prediction
1184                            .is_none()
1185                        {
1186                            prediction_result
1187                        } else {
1188                            EditPredictionResult {
1189                                id: prediction_result.id,
1190                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1191                            }
1192                        },
1193                        PredictionRequestedBy::DiagnosticsUpdate,
1194                    ))
1195                })
1196            })
1197        });
1198    }
1199
    /// Minimum delay between the starts of two prediction refreshes for the same
    /// throttle entity (applied in `queue_prediction_refresh`).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Throttling is disabled in tests so queued refreshes run without waiting.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1204
    /// Schedules `do_refresh` as a pending prediction request for `project`.
    ///
    /// Throttling: if the previous refresh for the same `throttle_entity` started
    /// less than `THROTTLE_TIMEOUT` ago, the task first waits out the remainder.
    /// A request that was cancelled in the meantime (its id found in
    /// `cancelled_predictions`) is abandoned without running `do_refresh`.
    ///
    /// When `do_refresh` finishes and the request was not cancelled:
    /// - If its result carries a rejection reason, that rejection is reported.
    /// - If it carries a prediction, it becomes the current prediction when
    ///   there is none, or when `should_replace_prediction` prefers it. The
    ///   replaced one is rejected with `Replaced`.
    /// - Otherwise the new prediction is rejected with `CurrentPreferred`.
    ///
    /// The finished request is then removed from `pending_predictions`, and
    /// every request queued before it is cancelled.
    ///
    /// At most two requests are kept pending. When two are already queued, the
    /// newer of the two is cancelled and replaced by this one.
    ///
    /// The spawned task resolves to the id of the prediction `do_refresh`
    /// produced, if any.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        // Captured now; used below to decide how long to throttle.
        let last_request = project_state.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Wait out the rest of the throttle window, but only if the last
            // refresh was for the same entity.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            // Starts as `true` so that a failed update (store dropped) also aborts.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                // The request may have been cancelled while `do_refresh` was running.
                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                            };

                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    // NOTE(review): `reject_current_prediction` also clears
                                    // `pending_predictions`, so requests queued after this
                                    // one are dropped as well — confirm this is intended.
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false);
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-fetch: the earlier borrow ended when `this` was used mutably above.
                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Remove this request from the pending list and cancel every request
                // that was queued before it (those at lower indices).
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Keep at most two pending requests. When two are already queued, the
        // newer one is cancelled in favor of this request.
        // NOTE(review): with more than two pending (presumably unreachable), `task`
        // is dropped here without being recorded.
        if project_state.pending_predictions.len() <= 1 {
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if project_state.pending_predictions.len() == 2 {
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1344
1345    pub fn request_prediction(
1346        &mut self,
1347        project: &Entity<Project>,
1348        active_buffer: &Entity<Buffer>,
1349        position: language::Anchor,
1350        trigger: PredictEditsRequestTrigger,
1351        cx: &mut Context<Self>,
1352    ) -> Task<Result<Option<EditPredictionResult>>> {
1353        self.request_prediction_internal(
1354            project.clone(),
1355            active_buffer.clone(),
1356            position,
1357            trigger,
1358            cx.has_flag::<Zeta2FeatureFlag>(),
1359            cx,
1360        )
1361    }
1362
1363    fn request_prediction_internal(
1364        &mut self,
1365        project: Entity<Project>,
1366        active_buffer: Entity<Buffer>,
1367        position: language::Anchor,
1368        trigger: PredictEditsRequestTrigger,
1369        allow_jump: bool,
1370        cx: &mut Context<Self>,
1371    ) -> Task<Result<Option<EditPredictionResult>>> {
1372        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1373
1374        self.get_or_init_project(&project, cx);
1375        let project_state = self.projects.get(&project.entity_id()).unwrap();
1376        let events = project_state.events(cx);
1377        let has_events = !events.is_empty();
1378        let debug_tx = project_state.debug_tx.clone();
1379
1380        let snapshot = active_buffer.read(cx).snapshot();
1381        let cursor_point = position.to_point(&snapshot);
1382        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1383        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1384        let diagnostic_search_range =
1385            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1386
1387        let related_files = if self.use_context {
1388            self.context_for_project(&project, cx)
1389        } else {
1390            Vec::new().into()
1391        };
1392
1393        let inputs = EditPredictionModelInput {
1394            project: project.clone(),
1395            buffer: active_buffer.clone(),
1396            snapshot: snapshot.clone(),
1397            position,
1398            events,
1399            related_files,
1400            recent_paths: project_state.recent_paths.clone(),
1401            trigger,
1402            diagnostic_search_range: diagnostic_search_range.clone(),
1403            debug_tx,
1404        };
1405
1406        let task = match self.edit_prediction_model {
1407            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
1408            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
1409            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1410            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1411        };
1412
1413        cx.spawn(async move |this, cx| {
1414            let prediction = task.await?;
1415
1416            if prediction.is_none() && allow_jump {
1417                let cursor_point = position.to_point(&snapshot);
1418                if has_events
1419                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1420                        active_buffer.clone(),
1421                        &snapshot,
1422                        diagnostic_search_range,
1423                        cursor_point,
1424                        &project,
1425                        cx,
1426                    )
1427                    .await?
1428                {
1429                    return this
1430                        .update(cx, |this, cx| {
1431                            this.request_prediction_internal(
1432                                project,
1433                                jump_buffer,
1434                                jump_position,
1435                                trigger,
1436                                false,
1437                                cx,
1438                            )
1439                        })?
1440                        .await;
1441                }
1442
1443                return anyhow::Ok(None);
1444            }
1445
1446            Ok(prediction)
1447        })
1448    }
1449
    /// Finds a location to jump to for a diagnostic-driven prediction.
    ///
    /// First it searches `active_buffer`. It takes the primary entry of each
    /// diagnostic group, skips those overlapping
    /// `active_buffer_diagnostic_search_range`, and picks the one whose start
    /// row is closest to `active_buffer_cursor_point`.
    ///
    /// If there is none, it falls back to another file. It takes the project
    /// paths reported by `diagnostic_summaries`, excluding the active buffer's,
    /// and chooses the one whose path shares the longest leading run of
    /// components with the active buffer's path. It opens that buffer and
    /// returns the start of its diagnostic with the smallest `severity` value
    /// (presumably the most severe, if severity follows LSP ordering; confirm).
    ///
    /// Returns `Ok(None)` when neither stage finds a candidate.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Fallback: look in another file with diagnostics.
        if jump_location.is_none() {
            // `None` when the active buffer is not backed by a file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1528
1529    async fn send_raw_llm_request(
1530        request: open_ai::Request,
1531        client: Arc<Client>,
1532        llm_token: LlmApiToken,
1533        app_version: Version,
1534        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1535        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
1536    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1537        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1538            http_client::Url::parse(&predict_edits_url)?
1539        } else {
1540            client
1541                .http_client()
1542                .build_zed_llm_url("/predict_edits/raw", &[])?
1543        };
1544
1545        #[cfg(feature = "eval-support")]
1546        let cache_key = if let Some(cache) = eval_cache {
1547            use collections::FxHasher;
1548            use std::hash::{Hash, Hasher};
1549
1550            let mut hasher = FxHasher::default();
1551            url.hash(&mut hasher);
1552            let request_str = serde_json::to_string_pretty(&request)?;
1553            request_str.hash(&mut hasher);
1554            let hash = hasher.finish();
1555
1556            let key = (eval_cache_kind, hash);
1557            if let Some(response_str) = cache.read(key) {
1558                return Ok((serde_json::from_str(&response_str)?, None));
1559            }
1560
1561            Some((cache, request_str, key))
1562        } else {
1563            None
1564        };
1565
1566        let (response, usage) = Self::send_api_request(
1567            |builder| {
1568                let req = builder
1569                    .uri(url.as_ref())
1570                    .body(serde_json::to_string(&request)?.into());
1571                Ok(req?)
1572            },
1573            client,
1574            llm_token,
1575            app_version,
1576        )
1577        .await?;
1578
1579        #[cfg(feature = "eval-support")]
1580        if let Some((cache, request, key)) = cache_key {
1581            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1582        }
1583
1584        Ok((response, usage))
1585    }
1586
1587    fn handle_api_response<T>(
1588        this: &WeakEntity<Self>,
1589        response: Result<(T, Option<EditPredictionUsage>)>,
1590        cx: &mut gpui::AsyncApp,
1591    ) -> Result<T> {
1592        match response {
1593            Ok((data, usage)) => {
1594                if let Some(usage) = usage {
1595                    this.update(cx, |this, cx| {
1596                        this.user_store.update(cx, |user_store, cx| {
1597                            user_store.update_edit_prediction_usage(usage, cx);
1598                        });
1599                    })
1600                    .ok();
1601                }
1602                Ok(data)
1603            }
1604            Err(err) => {
1605                if err.is::<ZedUpdateRequiredError>() {
1606                    cx.update(|cx| {
1607                        this.update(cx, |this, _cx| {
1608                            this.update_required = true;
1609                        })
1610                        .ok();
1611
1612                        let error_message: SharedString = err.to_string().into();
1613                        show_app_notification(
1614                            NotificationId::unique::<ZedUpdateRequiredError>(),
1615                            cx,
1616                            move |cx| {
1617                                cx.new(|cx| {
1618                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1619                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1620                                })
1621                            },
1622                        );
1623                    })
1624                    .ok();
1625                }
1626                Err(err)
1627            }
1628        }
1629    }
1630
1631    async fn send_api_request<Res>(
1632        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1633        client: Arc<Client>,
1634        llm_token: LlmApiToken,
1635        app_version: Version,
1636    ) -> Result<(Res, Option<EditPredictionUsage>)>
1637    where
1638        Res: DeserializeOwned,
1639    {
1640        let http_client = client.http_client();
1641        let mut token = llm_token.acquire(&client).await?;
1642        let mut did_retry = false;
1643
1644        loop {
1645            let request_builder = http_client::Request::builder().method(Method::POST);
1646
1647            let request = build(
1648                request_builder
1649                    .header("Content-Type", "application/json")
1650                    .header("Authorization", format!("Bearer {}", token))
1651                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
1652            )?;
1653
1654            let mut response = http_client.send(request).await?;
1655
1656            if let Some(minimum_required_version) = response
1657                .headers()
1658                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1659                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1660            {
1661                anyhow::ensure!(
1662                    app_version >= minimum_required_version,
1663                    ZedUpdateRequiredError {
1664                        minimum_version: minimum_required_version
1665                    }
1666                );
1667            }
1668
1669            if response.status().is_success() {
1670                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1671
1672                let mut body = Vec::new();
1673                response.body_mut().read_to_end(&mut body).await?;
1674                return Ok((serde_json::from_slice(&body)?, usage));
1675            } else if !did_retry
1676                && response
1677                    .headers()
1678                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1679                    .is_some()
1680            {
1681                did_retry = true;
1682                token = llm_token.refresh(&client).await?;
1683            } else {
1684                let mut body = String::new();
1685                response.body_mut().read_to_string(&mut body).await?;
1686                anyhow::bail!(
1687                    "Request failed with status: {:?}\nBody: {}",
1688                    response.status(),
1689                    body
1690                );
1691            }
1692        }
1693    }
1694
1695    pub fn refresh_context(
1696        &mut self,
1697        project: &Entity<Project>,
1698        buffer: &Entity<language::Buffer>,
1699        cursor_position: language::Anchor,
1700        cx: &mut Context<Self>,
1701    ) {
1702        if self.use_context {
1703            self.get_or_init_project(project, cx)
1704                .context
1705                .update(cx, |store, cx| {
1706                    store.refresh(buffer.clone(), cursor_position, cx);
1707                });
1708        }
1709    }
1710
1711    #[cfg(feature = "eval-support")]
1712    pub fn set_context_for_buffer(
1713        &mut self,
1714        project: &Entity<Project>,
1715        related_files: Vec<RelatedFile>,
1716        cx: &mut Context<Self>,
1717    ) {
1718        self.get_or_init_project(project, cx)
1719            .context
1720            .update(cx, |store, _| {
1721                store.set_related_files(related_files);
1722            });
1723    }
1724
1725    fn is_file_open_source(
1726        &self,
1727        project: &Entity<Project>,
1728        file: &Arc<dyn File>,
1729        cx: &App,
1730    ) -> bool {
1731        if !file.is_local() || file.is_private() {
1732            return false;
1733        }
1734        let Some(project_state) = self.projects.get(&project.entity_id()) else {
1735            return false;
1736        };
1737        project_state
1738            .license_detection_watchers
1739            .get(&file.worktree_id(cx))
1740            .as_ref()
1741            .is_some_and(|watcher| watcher.is_project_open_source())
1742    }
1743
1744    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1745        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1746    }
1747
1748    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
1749        if !self.data_collection_choice.is_enabled() {
1750            return false;
1751        }
1752        events.iter().all(|event| {
1753            matches!(
1754                event.as_ref(),
1755                zeta_prompt::Event::BufferChange {
1756                    in_open_source_repo: true,
1757                    ..
1758                }
1759            )
1760        })
1761    }
1762
1763    fn load_data_collection_choice() -> DataCollectionChoice {
1764        let choice = KEY_VALUE_STORE
1765            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1766            .log_err()
1767            .flatten();
1768
1769        match choice.as_deref() {
1770            Some("true") => DataCollectionChoice::Enabled,
1771            Some("false") => DataCollectionChoice::Disabled,
1772            Some(_) => {
1773                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1774                DataCollectionChoice::NotAnswered
1775            }
1776            None => DataCollectionChoice::NotAnswered,
1777        }
1778    }
1779
1780    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1781        self.data_collection_choice = self.data_collection_choice.toggle();
1782        let new_choice = self.data_collection_choice;
1783        db::write_and_log(cx, move || {
1784            KEY_VALUE_STORE.write_kvp(
1785                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1786                new_choice.is_enabled().to_string(),
1787            )
1788        });
1789    }
1790
    /// Iterates over the predictions recorded in `shown_predictions`.
    ///
    /// The iterator is double-ended, so callers can walk it from either end.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
1794
    /// Returns the number of entries in `shown_predictions`.
    // NOTE(review): says "completions" while the sibling accessor says "predictions";
    // the name is kept as-is for caller compatibility.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
1798
    /// Returns whether `id` is in the set of rated predictions.
    ///
    /// `rate_prediction` adds ids to this set.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
1802
1803    pub fn rate_prediction(
1804        &mut self,
1805        prediction: &EditPrediction,
1806        rating: EditPredictionRating,
1807        feedback: String,
1808        cx: &mut Context<Self>,
1809    ) {
1810        self.rated_predictions.insert(prediction.id.clone());
1811        telemetry::event!(
1812            "Edit Prediction Rated",
1813            rating,
1814            inputs = prediction.inputs,
1815            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
1816            feedback
1817        );
1818        self.client.telemetry().flush_events().detach();
1819        cx.notify();
1820    }
1821
1822    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
1823        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
1824            && all_language_settings(None, cx).edit_predictions.use_context;
1825    }
1826}
1827
/// Error produced when the server reports that this Zed build is too old for edit
/// predictions.
///
/// It is created in `send_api_request` when the minimum-required-version response header
/// names a version newer than the running app. `handle_api_response` detects it to show
/// an "Update Zed" notification.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum Zed version the server said it accepts.
    minimum_version: Version,
}
1835
/// Key for an [`EvalCache`] entry: the kind of cached data plus a `u64`.
// NOTE(review): the `u64` is presumably a hash of the request input — confirm against
// the code that builds cache keys.
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
1838
/// Category of an [`EvalCache`] entry, used as the first element of an [`EvalCacheKey`].
///
/// Its `Display` form is the lowercase variant name.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    /// Displayed as `context`.
    Context,
    /// Displayed as `search`.
    Search,
    /// Displayed as `prediction`.
    Prediction,
}
1846
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the lowercase name of the kind: `context`, `search`, or `prediction`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(label)
    }
}
1857
/// Cache hook available under the `eval-support` feature for storing API results.
///
/// Implementations must be `Send + Sync`.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the value previously stored under `key`, if any.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` under `key`.
    ///
    /// Within this file it is called with the serialized request as `input` and the
    /// pretty-printed JSON response as `value`. How implementations use `input` is up to
    /// them — TODO confirm.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
1863
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been recorded yet.
    NotAnswered,
    /// Data collection is allowed.
    Enabled,
    /// Data collection is not allowed.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for an explicit opt-in.
    pub fn is_enabled(self) -> bool {
        match self {
            Self::NotAnswered | Self::Disabled => false,
            Self::Enabled => true,
        }
    }

    /// Returns `true` once any explicit choice has been made.
    pub fn is_answered(self) -> bool {
        match self {
            Self::NotAnswered => false,
            Self::Enabled | Self::Disabled => true,
        }
    }

    /// Returns the opposite choice.
    ///
    /// `Enabled` becomes `Disabled`. `Disabled` and `NotAnswered` both become `Enabled`.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        Self::from(!self.is_enabled())
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to [`DataCollectionChoice::Enabled`] and `false` to
    /// [`DataCollectionChoice::Disabled`].
    fn from(value: bool) -> Self {
        if value {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
1904
1905struct ZedPredictUpsell;
1906
1907impl Dismissable for ZedPredictUpsell {
1908    const KEY: &'static str = "dismissed-edit-predict-upsell";
1909
1910    fn dismissed() -> bool {
1911        // To make this backwards compatible with older versions of Zed, we
1912        // check if the user has seen the previous Edit Prediction Onboarding
1913        // before, by checking the data collection choice which was written to
1914        // the database once the user clicked on "Accept and Enable"
1915        if KEY_VALUE_STORE
1916            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1917            .log_err()
1918            .is_some_and(|s| s.is_some())
1919        {
1920            return true;
1921        }
1922
1923        KEY_VALUE_STORE
1924            .read_kvp(Self::KEY)
1925            .log_err()
1926            .is_some_and(|s| s.is_some())
1927    }
1928}
1929
1930pub fn should_show_upsell_modal() -> bool {
1931    !ZedPredictUpsell::dismissed()
1932}
1933
1934pub fn init(cx: &mut App) {
1935    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
1936        workspace.register_action(
1937            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
1938                ZedPredictModal::toggle(
1939                    workspace,
1940                    workspace.user_store().clone(),
1941                    workspace.client().clone(),
1942                    window,
1943                    cx,
1944                )
1945            },
1946        );
1947
1948        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
1949            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
1950                settings
1951                    .project
1952                    .all_languages
1953                    .features
1954                    .get_or_insert_default()
1955                    .edit_prediction_provider = Some(EditPredictionProvider::None)
1956            });
1957        });
1958    })
1959    .detach();
1960}