edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
   7    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
   8    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
   9    ZED_VERSION_HEADER_NAME,
  10};
  11use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
  12use collections::{HashMap, HashSet};
  13use db::kvp::{Dismissable, KEY_VALUE_STORE};
  14use edit_prediction_context::EditPredictionExcerptOptions;
  15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::{
  18    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  19    channel::{
  20        mpsc::{self, UnboundedReceiver},
  21        oneshot,
  22    },
  23    select_biased,
  24};
  25use gpui::BackgroundExecutor;
  26use gpui::{
  27    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  28    http_client::{self, AsyncBody, Method},
  29    prelude::*,
  30};
  31use language::language_settings::all_language_settings;
  32use language::{Anchor, Buffer, File, Point, ToPoint};
  33use language::{BufferSnapshot, OffsetRangeExt};
  34use language_model::{LlmApiToken, RefreshLlmTokenListener};
  35use project::{Project, ProjectPath, WorktreeId};
  36use release_channel::AppVersion;
  37use semver::Version;
  38use serde::de::DeserializeOwned;
  39use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
  40use std::collections::{VecDeque, hash_map};
  41use workspace::Workspace;
  42
  43use std::ops::Range;
  44use std::path::Path;
  45use std::rc::Rc;
  46use std::str::FromStr as _;
  47use std::sync::{Arc, LazyLock};
  48use std::time::{Duration, Instant};
  49use std::{env, mem};
  50use thiserror::Error;
  51use util::{RangeExt as _, ResultExt as _};
  52use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  53
  54mod cursor_excerpt;
  55mod license_detection;
  56pub mod mercury;
  57mod onboarding_modal;
  58pub mod open_ai_response;
  59mod prediction;
  60pub mod sweep_ai;
  61pub mod udiff;
  62mod xml_edits;
  63mod zed_edit_prediction_delegate;
  64pub mod zeta1;
  65pub mod zeta2;
  66
  67#[cfg(test)]
  68mod edit_prediction_tests;
  69
  70use crate::license_detection::LicenseDetectionWatcher;
  71use crate::mercury::Mercury;
  72use crate::onboarding_modal::ZedPredictModal;
  73pub use crate::prediction::EditPrediction;
  74pub use crate::prediction::EditPredictionId;
  75pub use crate::prediction::EditPredictionInputs;
  76use crate::prediction::EditPredictionResult;
  77pub use crate::sweep_ai::SweepAi;
  78pub use telemetry_events::EditPredictionRating;
  79pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  80
// Declares the actions that live in the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  90
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Largest row distance between the end of the previous edit and the end of a
/// new edit for the two to be coalesced into a single pending event.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): presumably the key under which the data-collection choice is
// persisted; its use is not visible in this part of the file.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably how long rejections are batched before being sent;
// its use is not visible in this part of the file.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
  96
  97pub struct SweepFeatureFlag;
  98
  99impl FeatureFlag for SweepFeatureFlag {
 100    const NAME: &str = "sweep-ai";
 101}
 102
 103pub struct MercuryFeatureFlag;
 104
 105impl FeatureFlag for MercuryFeatureFlag {
 106    const NAME: &str = "mercury";
 107}
 108
/// Options a freshly constructed `EditPredictionStore` starts with.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    prompt_format: PromptFormat::DEFAULT,
};
 118
 119static USE_OLLAMA: LazyLock<bool> =
 120    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 121
 122static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 123    match env::var("ZED_ZETA2_MODEL").as_deref() {
 124        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 125        Ok(model) => model,
 126        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 127        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 128    }
 129    .to_string()
 130});
 131static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 132    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 133        if *USE_OLLAMA {
 134            Some("http://localhost:11434/v1/chat/completions".into())
 135        } else {
 136            None
 137        }
 138    })
 139});
 140
/// Feature flag registered under the name `"zeta2"`.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    /// Reports this flag as enabled for staff.
    fn enabled_for_staff() -> bool {
        true
    }
}
 150
/// Wrapper that lets the shared `EditPredictionStore` entity be stored as a
/// gpui global.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 155
/// Central state for edit predictions: per-project tracking, model selection,
/// provider handles, and the channel used to report rejected predictions.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Refreshed whenever the `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    use_context: bool,
    options: ZetaOptions,
    update_required: bool,
    /// When set, debug events are forwarded to the receiver handed out by
    /// `debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Sending half of the channel drained by the background task spawned in
    /// `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
 176
/// Which backend produces edit predictions.
///
/// The derived `Default` is `Zeta1`; note that `EditPredictionStore::new`
/// explicitly starts with `Zeta2` instead.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
}
 185
/// Tunable options for prediction requests; see `DEFAULT_OPTIONS` for the
/// values a new store starts with.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Options controlling the excerpt taken around the cursor.
    pub context: EditPredictionExcerptOptions,
    /// Upper bound on prompt size, in bytes.
    pub max_prompt_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
}
 192
/// Events delivered to the receiver returned by
/// `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    /// Sent when a project's related-excerpt store starts a refresh.
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    /// Sent when a project's related-excerpt store finishes a refresh.
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    // NOTE(review): presumably sent when a prediction request is issued; the
    // send site is not visible in this part of the file.
    EditPredictionRequested(EditPredictionRequestedDebugEvent),
}
 199
/// Payload of `DebugEvent::ContextRetrievalStarted`.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    /// Entity id of the project whose context is being refreshed.
    pub project_entity_id: EntityId,
    /// When the event was created.
    pub timestamp: Instant,
    /// The sender visible in this file always supplies an empty string here.
    pub search_prompt: String,
}
 206
/// Payload of `DebugEvent::ContextRetrievalFinished`.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    /// Entity id of the project whose context was refreshed.
    pub project_entity_id: EntityId,
    /// When the event was created.
    pub timestamp: Instant,
    /// Label/value pairs describing the refresh (cache hits, LSP latencies).
    pub metadata: Vec<(&'static str, SharedString)>,
}
 213
/// Payload of `DebugEvent::EditPredictionRequested`.
// NOTE(review): the field notes below are inferred from names and types; the
// code that fills this struct is not visible in this part of the file.
#[derive(Debug)]
pub struct EditPredictionRequestedDebugEvent {
    pub inputs: EditPredictionInputs,
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Presumably the locally built prompt, or an error message on failure.
    pub local_prompt: Result<String, String>,
    /// Presumably resolves to the model response (or an error message) and a
    /// duration once the request completes.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}
 223
/// Alias for the debug info type defined by the `predict_edits_v3` protocol.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 225
 226struct ProjectState {
 227    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
 228    last_event: Option<LastEvent>,
 229    recent_paths: VecDeque<ProjectPath>,
 230    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 231    current_prediction: Option<CurrentEditPrediction>,
 232    next_pending_prediction_id: usize,
 233    pending_predictions: ArrayVec<PendingPrediction, 2>,
 234    context_updates_tx: smol::channel::Sender<()>,
 235    context_updates_rx: smol::channel::Receiver<()>,
 236    last_prediction_refresh: Option<(EntityId, Instant)>,
 237    cancelled_predictions: HashSet<usize>,
 238    context: Entity<RelatedExcerptStore>,
 239    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 240    _subscription: gpui::Subscription,
 241}
 242
 243impl ProjectState {
 244    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
 245        self.events
 246            .iter()
 247            .cloned()
 248            .chain(
 249                self.last_event
 250                    .as_ref()
 251                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 252            )
 253            .collect()
 254    }
 255
 256    fn cancel_pending_prediction(
 257        &mut self,
 258        pending_prediction: PendingPrediction,
 259        cx: &mut Context<EditPredictionStore>,
 260    ) {
 261        self.cancelled_predictions.insert(pending_prediction.id);
 262
 263        cx.spawn(async move |this, cx| {
 264            let Some(prediction_id) = pending_prediction.task.await else {
 265                return;
 266            };
 267
 268            this.update(cx, |this, _cx| {
 269                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 270            })
 271            .ok();
 272        })
 273        .detach()
 274    }
 275}
 276
 277#[derive(Debug, Clone)]
 278struct CurrentEditPrediction {
 279    pub requested_by: PredictionRequestedBy,
 280    pub prediction: EditPrediction,
 281    pub was_shown: bool,
 282}
 283
 284impl CurrentEditPrediction {
 285    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 286        let Some(new_edits) = self
 287            .prediction
 288            .interpolate(&self.prediction.buffer.read(cx))
 289        else {
 290            return false;
 291        };
 292
 293        if self.prediction.buffer != old_prediction.prediction.buffer {
 294            return true;
 295        }
 296
 297        let Some(old_edits) = old_prediction
 298            .prediction
 299            .interpolate(&old_prediction.prediction.buffer.read(cx))
 300        else {
 301            return true;
 302        };
 303
 304        let requested_by_buffer_id = self.requested_by.buffer_id();
 305
 306        // This reduces the occurrence of UI thrash from replacing edits
 307        //
 308        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 309        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 310            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 311            && old_edits.len() == 1
 312            && new_edits.len() == 1
 313        {
 314            let (old_range, old_text) = &old_edits[0];
 315            let (new_range, new_text) = &new_edits[0];
 316            new_range == old_range && new_text.starts_with(old_text.as_ref())
 317        } else {
 318            true
 319        }
 320    }
 321}
 322
 323#[derive(Debug, Clone)]
 324enum PredictionRequestedBy {
 325    DiagnosticsUpdate,
 326    Buffer(EntityId),
 327}
 328
 329impl PredictionRequestedBy {
 330    pub fn buffer_id(&self) -> Option<EntityId> {
 331        match self {
 332            PredictionRequestedBy::DiagnosticsUpdate => None,
 333            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 334        }
 335    }
 336}
 337
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    /// Resolves to the id of the produced prediction, or `None` if there is
    /// none.
    task: Task<Option<EditPredictionId>>,
}
 343
 344/// A prediction from the perspective of a buffer.
 345#[derive(Debug)]
 346enum BufferEditPrediction<'a> {
 347    Local { prediction: &'a EditPrediction },
 348    Jump { prediction: &'a EditPrediction },
 349}
 350
 351#[cfg(test)]
 352impl std::ops::Deref for BufferEditPrediction<'_> {
 353    type Target = EditPrediction;
 354
 355    fn deref(&self) -> &Self::Target {
 356        match self {
 357            BufferEditPrediction::Local { prediction } => prediction,
 358            BufferEditPrediction::Jump { prediction } => prediction,
 359        }
 360    }
 361}
 362
/// Tracking state for a buffer whose edits are being observed.
struct RegisteredBuffer {
    /// Snapshot taken the last time changes were reported for this buffer.
    snapshot: BufferSnapshot,
    /// Keeps the edit subscription and the release observer alive.
    _subscriptions: [gpui::Subscription; 2],
}
 367
 368struct LastEvent {
 369    old_snapshot: BufferSnapshot,
 370    new_snapshot: BufferSnapshot,
 371    end_edit_anchor: Option<Anchor>,
 372}
 373
 374impl LastEvent {
 375    pub fn finalize(
 376        &self,
 377        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 378        cx: &App,
 379    ) -> Option<Arc<predict_edits_v3::Event>> {
 380        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
 381        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
 382
 383        let file = self.new_snapshot.file();
 384        let old_file = self.old_snapshot.file();
 385
 386        let in_open_source_repo = [file, old_file].iter().all(|file| {
 387            file.is_some_and(|file| {
 388                license_detection_watchers
 389                    .get(&file.worktree_id(cx))
 390                    .is_some_and(|watcher| watcher.is_project_open_source())
 391            })
 392        });
 393
 394        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
 395
 396        if path == old_path && diff.is_empty() {
 397            None
 398        } else {
 399            Some(Arc::new(predict_edits_v3::Event::BufferChange {
 400                old_path,
 401                path,
 402                diff,
 403                in_open_source_repo,
 404                // TODO: Actually detect if this edit was predicted or not
 405                predicted: false,
 406            }))
 407        }
 408    }
 409}
 410
 411fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
 412    if let Some(file) = snapshot.file() {
 413        file.full_path(cx).into()
 414    } else {
 415        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 416    }
 417}
 418
 419impl EditPredictionStore {
 420    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 421        cx.try_global::<EditPredictionStoreGlobal>()
 422            .map(|global| global.0.clone())
 423    }
 424
 425    pub fn global(
 426        client: &Arc<Client>,
 427        user_store: &Entity<UserStore>,
 428        cx: &mut App,
 429    ) -> Entity<Self> {
 430        cx.try_global::<EditPredictionStoreGlobal>()
 431            .map(|global| global.0.clone())
 432            .unwrap_or_else(|| {
 433                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 434                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 435                ep_store
 436            })
 437    }
 438
 439    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 440        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 441        let data_collection_choice = Self::load_data_collection_choice();
 442
 443        let llm_token = LlmApiToken::default();
 444
 445        let (reject_tx, reject_rx) = mpsc::unbounded();
 446        cx.background_spawn({
 447            let client = client.clone();
 448            let llm_token = llm_token.clone();
 449            let app_version = AppVersion::global(cx);
 450            let background_executor = cx.background_executor().clone();
 451            async move {
 452                Self::handle_rejected_predictions(
 453                    reject_rx,
 454                    client,
 455                    llm_token,
 456                    app_version,
 457                    background_executor,
 458                )
 459                .await
 460            }
 461        })
 462        .detach();
 463
 464        let mut this = Self {
 465            projects: HashMap::default(),
 466            client,
 467            user_store,
 468            options: DEFAULT_OPTIONS,
 469            use_context: false,
 470            llm_token,
 471            _llm_token_subscription: cx.subscribe(
 472                &refresh_llm_token_listener,
 473                |this, _listener, _event, cx| {
 474                    let client = this.client.clone();
 475                    let llm_token = this.llm_token.clone();
 476                    cx.spawn(async move |_this, _cx| {
 477                        llm_token.refresh(&client).await?;
 478                        anyhow::Ok(())
 479                    })
 480                    .detach_and_log_err(cx);
 481                },
 482            ),
 483            update_required: false,
 484            debug_tx: None,
 485            #[cfg(feature = "eval-support")]
 486            eval_cache: None,
 487            edit_prediction_model: EditPredictionModel::Zeta2,
 488            sweep_ai: SweepAi::new(cx),
 489            mercury: Mercury::new(cx),
 490            data_collection_choice,
 491            reject_predictions_tx: reject_tx,
 492            rated_predictions: Default::default(),
 493            shown_predictions: Default::default(),
 494        };
 495
 496        this.configure_context_retrieval(cx);
 497        let weak_this = cx.weak_entity();
 498        cx.on_flags_ready(move |_, cx| {
 499            weak_this
 500                .update(cx, |this, cx| this.configure_context_retrieval(cx))
 501                .ok();
 502        })
 503        .detach();
 504        cx.observe_global::<SettingsStore>(|this, cx| {
 505            this.configure_context_retrieval(cx);
 506        })
 507        .detach();
 508
 509        this
 510    }
 511
    /// Selects which backend is used for edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
 515
 516    pub fn has_sweep_api_token(&self) -> bool {
 517        self.sweep_ai
 518            .api_token
 519            .clone()
 520            .now_or_never()
 521            .flatten()
 522            .is_some()
 523    }
 524
 525    pub fn has_mercury_api_token(&self) -> bool {
 526        self.mercury
 527            .api_token
 528            .clone()
 529            .now_or_never()
 530            .flatten()
 531            .is_some()
 532    }
 533
    /// Installs the cache used when built with the `eval-support` feature,
    /// replacing any previously installed one.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 538
 539    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<DebugEvent> {
 540        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 541        self.debug_tx = Some(debug_watch_tx);
 542        debug_watch_rx
 543    }
 544
    /// Returns the options currently in effect.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 548
    /// Replaces the options currently in effect.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 552
    /// Sets the `use_context` flag.
    // NOTE(review): presumably toggles whether retrieved context is included
    // in requests; the readers of this flag are not visible here.
    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }
 556
 557    pub fn clear_history(&mut self) {
 558        for project_state in self.projects.values_mut() {
 559            project_state.events.clear();
 560        }
 561    }
 562
 563    pub fn context_for_project<'a>(
 564        &'a self,
 565        project: &Entity<Project>,
 566        cx: &'a App,
 567    ) -> &'a [RelatedFile] {
 568        self.projects
 569            .get(&project.entity_id())
 570            .map(|project| project.context.read(cx).related_files())
 571            .unwrap_or(&[])
 572    }
 573
 574    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 575        if self.edit_prediction_model == EditPredictionModel::Zeta2 {
 576            self.user_store.read(cx).edit_prediction_usage()
 577        } else {
 578            None
 579        }
 580    }
 581
    /// Ensures per-project state exists for `project`, creating it if needed.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
 585
    /// Starts tracking `buffer` within `project`, creating the project state
    /// first if it does not exist yet.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
 595
 596    fn get_or_init_project(
 597        &mut self,
 598        project: &Entity<Project>,
 599        cx: &mut Context<Self>,
 600    ) -> &mut ProjectState {
 601        let entity_id = project.entity_id();
 602        let (context_updates_tx, context_updates_rx) = smol::channel::unbounded();
 603        self.projects
 604            .entry(entity_id)
 605            .or_insert_with(|| ProjectState {
 606                context: {
 607                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 608                    cx.subscribe(
 609                        &related_excerpt_store,
 610                        move |this, _, event, _| match event {
 611                            RelatedExcerptStoreEvent::StartedRefresh => {
 612                                if let Some(debug_tx) = this.debug_tx.clone() {
 613                                    debug_tx
 614                                        .unbounded_send(DebugEvent::ContextRetrievalStarted(
 615                                            ContextRetrievalStartedDebugEvent {
 616                                                project_entity_id: entity_id,
 617                                                timestamp: Instant::now(),
 618                                                search_prompt: String::new(),
 619                                            },
 620                                        ))
 621                                        .ok();
 622                                }
 623                            }
 624                            RelatedExcerptStoreEvent::FinishedRefresh {
 625                                cache_hit_count,
 626                                cache_miss_count,
 627                                mean_definition_latency,
 628                                max_definition_latency,
 629                            } => {
 630                                if let Some(debug_tx) = this.debug_tx.clone() {
 631                                    debug_tx
 632                                        .unbounded_send(DebugEvent::ContextRetrievalFinished(
 633                                            ContextRetrievalFinishedDebugEvent {
 634                                                project_entity_id: entity_id,
 635                                                timestamp: Instant::now(),
 636                                                metadata: vec![
 637                                                    (
 638                                                        "Cache Hits",
 639                                                        format!(
 640                                                            "{}/{}",
 641                                                            cache_hit_count,
 642                                                            cache_hit_count + cache_miss_count
 643                                                        )
 644                                                        .into(),
 645                                                    ),
 646                                                    (
 647                                                        "Max LSP Time",
 648                                                        format!(
 649                                                            "{} ms",
 650                                                            max_definition_latency.as_millis()
 651                                                        )
 652                                                        .into(),
 653                                                    ),
 654                                                    (
 655                                                        "Mean LSP Time",
 656                                                        format!(
 657                                                            "{} ms",
 658                                                            mean_definition_latency.as_millis()
 659                                                        )
 660                                                        .into(),
 661                                                    ),
 662                                                ],
 663                                            },
 664                                        ))
 665                                        .ok();
 666                                }
 667                                if let Some(project_state) = this.projects.get(&entity_id) {
 668                                    project_state.context_updates_tx.send_blocking(()).ok();
 669                                }
 670                            }
 671                        },
 672                    )
 673                    .detach();
 674                    related_excerpt_store
 675                },
 676                events: VecDeque::new(),
 677                last_event: None,
 678                recent_paths: VecDeque::new(),
 679                context_updates_rx,
 680                context_updates_tx,
 681                registered_buffers: HashMap::default(),
 682                current_prediction: None,
 683                cancelled_predictions: HashSet::default(),
 684                pending_predictions: ArrayVec::new(),
 685                next_pending_prediction_id: 0,
 686                last_prediction_refresh: None,
 687                license_detection_watchers: HashMap::default(),
 688                _subscription: cx.subscribe(&project, Self::handle_project_event),
 689            })
 690    }
 691
 692    pub fn project_context_updates(
 693        &self,
 694        project: &Entity<Project>,
 695    ) -> Option<smol::channel::Receiver<()>> {
 696        let project_state = self.projects.get(&project.entity_id())?;
 697        Some(project_state.context_updates_rx.clone())
 698    }
 699
 700    fn handle_project_event(
 701        &mut self,
 702        project: Entity<Project>,
 703        event: &project::Event,
 704        cx: &mut Context<Self>,
 705    ) {
 706        // TODO [zeta2] init with recent paths
 707        match event {
 708            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 709                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 710                    return;
 711                };
 712                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 713                if let Some(path) = path {
 714                    if let Some(ix) = project_state
 715                        .recent_paths
 716                        .iter()
 717                        .position(|probe| probe == &path)
 718                    {
 719                        project_state.recent_paths.remove(ix);
 720                    }
 721                    project_state.recent_paths.push_front(path);
 722                }
 723            }
 724            project::Event::DiagnosticsUpdated { .. } => {
 725                if cx.has_flag::<Zeta2FeatureFlag>() {
 726                    self.refresh_prediction_from_diagnostics(project, cx);
 727                }
 728            }
 729            _ => (),
 730        }
 731    }
 732
 733    fn register_buffer_impl<'a>(
 734        project_state: &'a mut ProjectState,
 735        buffer: &Entity<Buffer>,
 736        project: &Entity<Project>,
 737        cx: &mut Context<Self>,
 738    ) -> &'a mut RegisteredBuffer {
 739        let buffer_id = buffer.entity_id();
 740
 741        if let Some(file) = buffer.read(cx).file() {
 742            let worktree_id = file.worktree_id(cx);
 743            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
 744                project_state
 745                    .license_detection_watchers
 746                    .entry(worktree_id)
 747                    .or_insert_with(|| {
 748                        let project_entity_id = project.entity_id();
 749                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
 750                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
 751                            else {
 752                                return;
 753                            };
 754                            project_state
 755                                .license_detection_watchers
 756                                .remove(&worktree_id);
 757                        })
 758                        .detach();
 759                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
 760                    });
 761            }
 762        }
 763
 764        match project_state.registered_buffers.entry(buffer_id) {
 765            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 766            hash_map::Entry::Vacant(entry) => {
 767                let snapshot = buffer.read(cx).snapshot();
 768                let project_entity_id = project.entity_id();
 769                entry.insert(RegisteredBuffer {
 770                    snapshot,
 771                    _subscriptions: [
 772                        cx.subscribe(buffer, {
 773                            let project = project.downgrade();
 774                            move |this, buffer, event, cx| {
 775                                if let language::BufferEvent::Edited = event
 776                                    && let Some(project) = project.upgrade()
 777                                {
 778                                    this.report_changes_for_buffer(&buffer, &project, cx);
 779                                }
 780                            }
 781                        }),
 782                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 783                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
 784                            else {
 785                                return;
 786                            };
 787                            project_state.registered_buffers.remove(&buffer_id);
 788                        }),
 789                    ],
 790                })
 791            }
 792        }
 793    }
 794
 795    fn report_changes_for_buffer(
 796        &mut self,
 797        buffer: &Entity<Buffer>,
 798        project: &Entity<Project>,
 799        cx: &mut Context<Self>,
 800    ) {
 801        let project_state = self.get_or_init_project(project, cx);
 802        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
 803
 804        let new_snapshot = buffer.read(cx).snapshot();
 805        if new_snapshot.version == registered_buffer.snapshot.version {
 806            return;
 807        }
 808
 809        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 810        let end_edit_anchor = new_snapshot
 811            .anchored_edits_since::<Point>(&old_snapshot.version)
 812            .last()
 813            .map(|(_, range)| range.end);
 814        let events = &mut project_state.events;
 815
 816        if let Some(LastEvent {
 817            new_snapshot: last_new_snapshot,
 818            end_edit_anchor: last_end_edit_anchor,
 819            ..
 820        }) = project_state.last_event.as_mut()
 821        {
 822            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
 823                == last_new_snapshot.remote_id()
 824                && old_snapshot.version == last_new_snapshot.version;
 825
 826            let should_coalesce = is_next_snapshot_of_same_buffer
 827                && end_edit_anchor
 828                    .as_ref()
 829                    .zip(last_end_edit_anchor.as_ref())
 830                    .is_some_and(|(a, b)| {
 831                        let a = a.to_point(&new_snapshot);
 832                        let b = b.to_point(&new_snapshot);
 833                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
 834                    });
 835
 836            if should_coalesce {
 837                *last_end_edit_anchor = end_edit_anchor;
 838                *last_new_snapshot = new_snapshot;
 839                return;
 840            }
 841        }
 842
 843        if events.len() + 1 >= EVENT_COUNT_MAX {
 844            events.pop_front();
 845        }
 846
 847        if let Some(event) = project_state.last_event.take() {
 848            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
 849        }
 850
 851        project_state.last_event = Some(LastEvent {
 852            old_snapshot,
 853            new_snapshot,
 854            end_edit_anchor,
 855        });
 856    }
 857
 858    fn current_prediction_for_buffer(
 859        &self,
 860        buffer: &Entity<Buffer>,
 861        project: &Entity<Project>,
 862        cx: &App,
 863    ) -> Option<BufferEditPrediction<'_>> {
 864        let project_state = self.projects.get(&project.entity_id())?;
 865
 866        let CurrentEditPrediction {
 867            requested_by,
 868            prediction,
 869            ..
 870        } = project_state.current_prediction.as_ref()?;
 871
 872        if prediction.targets_buffer(buffer.read(cx)) {
 873            Some(BufferEditPrediction::Local { prediction })
 874        } else {
 875            let show_jump = match requested_by {
 876                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
 877                    requested_by_buffer_id == &buffer.entity_id()
 878                }
 879                PredictionRequestedBy::DiagnosticsUpdate => true,
 880            };
 881
 882            if show_jump {
 883                Some(BufferEditPrediction::Jump { prediction })
 884            } else {
 885                None
 886            }
 887        }
 888    }
 889
 890    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 891        match self.edit_prediction_model {
 892            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
 893            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
 894        }
 895
 896        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 897            return;
 898        };
 899
 900        let Some(prediction) = project_state.current_prediction.take() else {
 901            return;
 902        };
 903        let request_id = prediction.prediction.id.to_string();
 904        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
 905            project_state.cancel_pending_prediction(pending_prediction, cx);
 906        }
 907
 908        let client = self.client.clone();
 909        let llm_token = self.llm_token.clone();
 910        let app_version = AppVersion::global(cx);
 911        cx.spawn(async move |this, cx| {
 912            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
 913                http_client::Url::parse(&predict_edits_url)?
 914            } else {
 915                client
 916                    .http_client()
 917                    .build_zed_llm_url("/predict_edits/accept", &[])?
 918            };
 919
 920            let response = cx
 921                .background_spawn(Self::send_api_request::<()>(
 922                    move |builder| {
 923                        let req = builder.uri(url.as_ref()).body(
 924                            serde_json::to_string(&AcceptEditPredictionBody {
 925                                request_id: request_id.clone(),
 926                            })?
 927                            .into(),
 928                        );
 929                        Ok(req?)
 930                    },
 931                    client,
 932                    llm_token,
 933                    app_version,
 934                ))
 935                .await;
 936
 937            Self::handle_api_response(&this, response, cx)?;
 938            anyhow::Ok(())
 939        })
 940        .detach_and_log_err(cx);
 941    }
 942
 943    async fn handle_rejected_predictions(
 944        rx: UnboundedReceiver<EditPredictionRejection>,
 945        client: Arc<Client>,
 946        llm_token: LlmApiToken,
 947        app_version: Version,
 948        background_executor: BackgroundExecutor,
 949    ) {
 950        let mut rx = std::pin::pin!(rx.peekable());
 951        let mut batched = Vec::new();
 952
 953        while let Some(rejection) = rx.next().await {
 954            batched.push(rejection);
 955
 956            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
 957                select_biased! {
 958                    next = rx.as_mut().peek().fuse() => {
 959                        if next.is_some() {
 960                            continue;
 961                        }
 962                    }
 963                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
 964                }
 965            }
 966
 967            let url = client
 968                .http_client()
 969                .build_zed_llm_url("/predict_edits/reject", &[])
 970                .unwrap();
 971
 972            let flush_count = batched
 973                .len()
 974                // in case items have accumulated after failure
 975                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
 976            let start = batched.len() - flush_count;
 977
 978            let body = RejectEditPredictionsBodyRef {
 979                rejections: &batched[start..],
 980            };
 981
 982            let result = Self::send_api_request::<()>(
 983                |builder| {
 984                    let req = builder
 985                        .uri(url.as_ref())
 986                        .body(serde_json::to_string(&body)?.into());
 987                    anyhow::Ok(req?)
 988                },
 989                client.clone(),
 990                llm_token.clone(),
 991                app_version.clone(),
 992            )
 993            .await;
 994
 995            if result.log_err().is_some() {
 996                batched.drain(start..);
 997            }
 998        }
 999    }
1000
1001    fn reject_current_prediction(
1002        &mut self,
1003        reason: EditPredictionRejectReason,
1004        project: &Entity<Project>,
1005    ) {
1006        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1007            project_state.pending_predictions.clear();
1008            if let Some(prediction) = project_state.current_prediction.take() {
1009                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1010            }
1011        };
1012    }
1013
1014    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1015        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1016            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1017                if !current_prediction.was_shown {
1018                    current_prediction.was_shown = true;
1019                    self.shown_predictions
1020                        .push_front(current_prediction.prediction.clone());
1021                    if self.shown_predictions.len() > 50 {
1022                        let completion = self.shown_predictions.pop_back().unwrap();
1023                        self.rated_predictions.remove(&completion.id);
1024                    }
1025                }
1026            }
1027        }
1028    }
1029
1030    fn reject_prediction(
1031        &mut self,
1032        prediction_id: EditPredictionId,
1033        reason: EditPredictionRejectReason,
1034        was_shown: bool,
1035    ) {
1036        match self.edit_prediction_model {
1037            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
1038            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1039        }
1040
1041        self.reject_predictions_tx
1042            .unbounded_send(EditPredictionRejection {
1043                request_id: prediction_id.to_string(),
1044                reason,
1045                was_shown,
1046            })
1047            .log_err();
1048    }
1049
1050    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1051        self.projects
1052            .get(&project.entity_id())
1053            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1054    }
1055
1056    pub fn refresh_prediction_from_buffer(
1057        &mut self,
1058        project: Entity<Project>,
1059        buffer: Entity<Buffer>,
1060        position: language::Anchor,
1061        cx: &mut Context<Self>,
1062    ) {
1063        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1064            let Some(request_task) = this
1065                .update(cx, |this, cx| {
1066                    this.request_prediction(
1067                        &project,
1068                        &buffer,
1069                        position,
1070                        PredictEditsRequestTrigger::Other,
1071                        cx,
1072                    )
1073                })
1074                .log_err()
1075            else {
1076                return Task::ready(anyhow::Ok(None));
1077            };
1078
1079            cx.spawn(async move |_cx| {
1080                request_task.await.map(|prediction_result| {
1081                    prediction_result.map(|prediction_result| {
1082                        (
1083                            prediction_result,
1084                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1085                        )
1086                    })
1087                })
1088            })
1089        })
1090    }
1091
1092    pub fn refresh_prediction_from_diagnostics(
1093        &mut self,
1094        project: Entity<Project>,
1095        cx: &mut Context<Self>,
1096    ) {
1097        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1098            return;
1099        };
1100
1101        // Prefer predictions from buffer
1102        if project_state.current_prediction.is_some() {
1103            return;
1104        };
1105
1106        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1107            let Some(open_buffer_task) = project
1108                .update(cx, |project, cx| {
1109                    project
1110                        .active_entry()
1111                        .and_then(|entry| project.path_for_entry(entry, cx))
1112                        .map(|path| project.open_buffer(path, cx))
1113                })
1114                .log_err()
1115                .flatten()
1116            else {
1117                return Task::ready(anyhow::Ok(None));
1118            };
1119
1120            cx.spawn(async move |cx| {
1121                let active_buffer = open_buffer_task.await?;
1122                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
1123
1124                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1125                    active_buffer,
1126                    &snapshot,
1127                    Default::default(),
1128                    Default::default(),
1129                    &project,
1130                    cx,
1131                )
1132                .await?
1133                else {
1134                    return anyhow::Ok(None);
1135                };
1136
1137                let Some(prediction_result) = this
1138                    .update(cx, |this, cx| {
1139                        this.request_prediction(
1140                            &project,
1141                            &jump_buffer,
1142                            jump_position,
1143                            PredictEditsRequestTrigger::Diagnostics,
1144                            cx,
1145                        )
1146                    })?
1147                    .await?
1148                else {
1149                    return anyhow::Ok(None);
1150                };
1151
1152                this.update(cx, |this, cx| {
1153                    Some((
1154                        if this
1155                            .get_or_init_project(&project, cx)
1156                            .current_prediction
1157                            .is_none()
1158                        {
1159                            prediction_result
1160                        } else {
1161                            EditPredictionResult {
1162                                id: prediction_result.id,
1163                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1164                            }
1165                        },
1166                        PredictionRequestedBy::DiagnosticsUpdate,
1167                    ))
1168                })
1169            })
1170        });
1171    }
1172
1173    #[cfg(not(test))]
1174    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1175    #[cfg(test)]
1176    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1177
1178    fn queue_prediction_refresh(
1179        &mut self,
1180        project: Entity<Project>,
1181        throttle_entity: EntityId,
1182        cx: &mut Context<Self>,
1183        do_refresh: impl FnOnce(
1184            WeakEntity<Self>,
1185            &mut AsyncApp,
1186        )
1187            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1188        + 'static,
1189    ) {
1190        let project_state = self.get_or_init_project(&project, cx);
1191        let pending_prediction_id = project_state.next_pending_prediction_id;
1192        project_state.next_pending_prediction_id += 1;
1193        let last_request = project_state.last_prediction_refresh;
1194
1195        let task = cx.spawn(async move |this, cx| {
1196            if let Some((last_entity, last_timestamp)) = last_request
1197                && throttle_entity == last_entity
1198                && let Some(timeout) =
1199                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1200            {
1201                cx.background_executor().timer(timeout).await;
1202            }
1203
1204            // If this task was cancelled before the throttle timeout expired,
1205            // do not perform a request.
1206            let mut is_cancelled = true;
1207            this.update(cx, |this, cx| {
1208                let project_state = this.get_or_init_project(&project, cx);
1209                if !project_state
1210                    .cancelled_predictions
1211                    .remove(&pending_prediction_id)
1212                {
1213                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1214                    is_cancelled = false;
1215                }
1216            })
1217            .ok();
1218            if is_cancelled {
1219                return None;
1220            }
1221
1222            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1223            let new_prediction_id = new_prediction_result
1224                .as_ref()
1225                .map(|(prediction, _)| prediction.id.clone());
1226
1227            // When a prediction completes, remove it from the pending list, and cancel
1228            // any pending predictions that were enqueued before it.
1229            this.update(cx, |this, cx| {
1230                let project_state = this.get_or_init_project(&project, cx);
1231
1232                let is_cancelled = project_state
1233                    .cancelled_predictions
1234                    .remove(&pending_prediction_id);
1235
1236                let new_current_prediction = if !is_cancelled
1237                    && let Some((prediction_result, requested_by)) = new_prediction_result
1238                {
1239                    match prediction_result.prediction {
1240                        Ok(prediction) => {
1241                            let new_prediction = CurrentEditPrediction {
1242                                requested_by,
1243                                prediction,
1244                                was_shown: false,
1245                            };
1246
1247                            if let Some(current_prediction) =
1248                                project_state.current_prediction.as_ref()
1249                            {
1250                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1251                                {
1252                                    this.reject_current_prediction(
1253                                        EditPredictionRejectReason::Replaced,
1254                                        &project,
1255                                    );
1256
1257                                    Some(new_prediction)
1258                                } else {
1259                                    this.reject_prediction(
1260                                        new_prediction.prediction.id,
1261                                        EditPredictionRejectReason::CurrentPreferred,
1262                                        false,
1263                                    );
1264                                    None
1265                                }
1266                            } else {
1267                                Some(new_prediction)
1268                            }
1269                        }
1270                        Err(reject_reason) => {
1271                            this.reject_prediction(prediction_result.id, reject_reason, false);
1272                            None
1273                        }
1274                    }
1275                } else {
1276                    None
1277                };
1278
1279                let project_state = this.get_or_init_project(&project, cx);
1280
1281                if let Some(new_prediction) = new_current_prediction {
1282                    project_state.current_prediction = Some(new_prediction);
1283                }
1284
1285                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1286                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1287                    if pending_prediction.id == pending_prediction_id {
1288                        pending_predictions.remove(ix);
1289                        for pending_prediction in pending_predictions.drain(0..ix) {
1290                            project_state.cancel_pending_prediction(pending_prediction, cx)
1291                        }
1292                        break;
1293                    }
1294                }
1295                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1296                cx.notify();
1297            })
1298            .ok();
1299
1300            new_prediction_id
1301        });
1302
1303        if project_state.pending_predictions.len() <= 1 {
1304            project_state.pending_predictions.push(PendingPrediction {
1305                id: pending_prediction_id,
1306                task,
1307            });
1308        } else if project_state.pending_predictions.len() == 2 {
1309            let pending_prediction = project_state.pending_predictions.pop().unwrap();
1310            project_state.pending_predictions.push(PendingPrediction {
1311                id: pending_prediction_id,
1312                task,
1313            });
1314            project_state.cancel_pending_prediction(pending_prediction, cx);
1315        }
1316    }
1317
1318    pub fn request_prediction(
1319        &mut self,
1320        project: &Entity<Project>,
1321        active_buffer: &Entity<Buffer>,
1322        position: language::Anchor,
1323        trigger: PredictEditsRequestTrigger,
1324        cx: &mut Context<Self>,
1325    ) -> Task<Result<Option<EditPredictionResult>>> {
1326        self.request_prediction_internal(
1327            project.clone(),
1328            active_buffer.clone(),
1329            position,
1330            trigger,
1331            cx.has_flag::<Zeta2FeatureFlag>(),
1332            cx,
1333        )
1334    }
1335
1336    fn request_prediction_internal(
1337        &mut self,
1338        project: Entity<Project>,
1339        active_buffer: Entity<Buffer>,
1340        position: language::Anchor,
1341        trigger: PredictEditsRequestTrigger,
1342        allow_jump: bool,
1343        cx: &mut Context<Self>,
1344    ) -> Task<Result<Option<EditPredictionResult>>> {
1345        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1346
1347        self.get_or_init_project(&project, cx);
1348        let project_state = self.projects.get(&project.entity_id()).unwrap();
1349        let events = project_state.events(cx);
1350        let has_events = !events.is_empty();
1351
1352        let snapshot = active_buffer.read(cx).snapshot();
1353        let cursor_point = position.to_point(&snapshot);
1354        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1355        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1356        let diagnostic_search_range =
1357            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1358
1359        let related_files = if self.use_context {
1360            self.context_for_project(&project, cx).to_vec()
1361        } else {
1362            Vec::new()
1363        };
1364
1365        let task = match self.edit_prediction_model {
1366            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(
1367                self,
1368                &project,
1369                &active_buffer,
1370                snapshot.clone(),
1371                position,
1372                events,
1373                trigger,
1374                cx,
1375            ),
1376            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
1377                self,
1378                &project,
1379                &active_buffer,
1380                snapshot.clone(),
1381                position,
1382                events,
1383                related_files,
1384                trigger,
1385                cx,
1386            ),
1387            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
1388                &project,
1389                &active_buffer,
1390                snapshot.clone(),
1391                position,
1392                events,
1393                &project_state.recent_paths,
1394                related_files,
1395                diagnostic_search_range.clone(),
1396                cx,
1397            ),
1398            EditPredictionModel::Mercury => self.mercury.request_prediction(
1399                &project,
1400                &active_buffer,
1401                snapshot.clone(),
1402                position,
1403                events,
1404                &project_state.recent_paths,
1405                related_files,
1406                diagnostic_search_range.clone(),
1407                cx,
1408            ),
1409        };
1410
1411        cx.spawn(async move |this, cx| {
1412            let prediction = task.await?;
1413
1414            if prediction.is_none() && allow_jump {
1415                let cursor_point = position.to_point(&snapshot);
1416                if has_events
1417                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1418                        active_buffer.clone(),
1419                        &snapshot,
1420                        diagnostic_search_range,
1421                        cursor_point,
1422                        &project,
1423                        cx,
1424                    )
1425                    .await?
1426                {
1427                    return this
1428                        .update(cx, |this, cx| {
1429                            this.request_prediction_internal(
1430                                project,
1431                                jump_buffer,
1432                                jump_position,
1433                                trigger,
1434                                false,
1435                                cx,
1436                            )
1437                        })?
1438                        .await;
1439                }
1440
1441                return anyhow::Ok(None);
1442            }
1443
1444            Ok(prediction)
1445        })
1446    }
1447
1448    async fn next_diagnostic_location(
1449        active_buffer: Entity<Buffer>,
1450        active_buffer_snapshot: &BufferSnapshot,
1451        active_buffer_diagnostic_search_range: Range<Point>,
1452        active_buffer_cursor_point: Point,
1453        project: &Entity<Project>,
1454        cx: &mut AsyncApp,
1455    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1456        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1457        let mut jump_location = active_buffer_snapshot
1458            .diagnostic_groups(None)
1459            .into_iter()
1460            .filter_map(|(_, group)| {
1461                let range = &group.entries[group.primary_ix]
1462                    .range
1463                    .to_point(&active_buffer_snapshot);
1464                if range.overlaps(&active_buffer_diagnostic_search_range) {
1465                    None
1466                } else {
1467                    Some(range.start)
1468                }
1469            })
1470            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1471            .map(|position| {
1472                (
1473                    active_buffer.clone(),
1474                    active_buffer_snapshot.anchor_before(position),
1475                )
1476            });
1477
1478        if jump_location.is_none() {
1479            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1480                let file = buffer.file()?;
1481
1482                Some(ProjectPath {
1483                    worktree_id: file.worktree_id(cx),
1484                    path: file.path().clone(),
1485                })
1486            })?;
1487
1488            let buffer_task = project.update(cx, |project, cx| {
1489                let (path, _, _) = project
1490                    .diagnostic_summaries(false, cx)
1491                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1492                    .max_by_key(|(path, _, _)| {
1493                        // find the buffer with errors that shares most parent directories
1494                        path.path
1495                            .components()
1496                            .zip(
1497                                active_buffer_path
1498                                    .as_ref()
1499                                    .map(|p| p.path.components())
1500                                    .unwrap_or_default(),
1501                            )
1502                            .take_while(|(a, b)| a == b)
1503                            .count()
1504                    })?;
1505
1506                Some(project.open_buffer(path, cx))
1507            })?;
1508
1509            if let Some(buffer_task) = buffer_task {
1510                let closest_buffer = buffer_task.await?;
1511
1512                jump_location = closest_buffer
1513                    .read_with(cx, |buffer, _cx| {
1514                        buffer
1515                            .buffer_diagnostics(None)
1516                            .into_iter()
1517                            .min_by_key(|entry| entry.diagnostic.severity)
1518                            .map(|entry| entry.range.start)
1519                    })?
1520                    .map(|position| (closest_buffer, position));
1521            }
1522        }
1523
1524        anyhow::Ok(jump_location)
1525    }
1526
1527    async fn send_raw_llm_request(
1528        request: open_ai::Request,
1529        client: Arc<Client>,
1530        llm_token: LlmApiToken,
1531        app_version: Version,
1532        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1533        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
1534    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1535        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1536            http_client::Url::parse(&predict_edits_url)?
1537        } else {
1538            client
1539                .http_client()
1540                .build_zed_llm_url("/predict_edits/raw", &[])?
1541        };
1542
1543        #[cfg(feature = "eval-support")]
1544        let cache_key = if let Some(cache) = eval_cache {
1545            use collections::FxHasher;
1546            use std::hash::{Hash, Hasher};
1547
1548            let mut hasher = FxHasher::default();
1549            url.hash(&mut hasher);
1550            let request_str = serde_json::to_string_pretty(&request)?;
1551            request_str.hash(&mut hasher);
1552            let hash = hasher.finish();
1553
1554            let key = (eval_cache_kind, hash);
1555            if let Some(response_str) = cache.read(key) {
1556                return Ok((serde_json::from_str(&response_str)?, None));
1557            }
1558
1559            Some((cache, request_str, key))
1560        } else {
1561            None
1562        };
1563
1564        let (response, usage) = Self::send_api_request(
1565            |builder| {
1566                let req = builder
1567                    .uri(url.as_ref())
1568                    .body(serde_json::to_string(&request)?.into());
1569                Ok(req?)
1570            },
1571            client,
1572            llm_token,
1573            app_version,
1574        )
1575        .await?;
1576
1577        #[cfg(feature = "eval-support")]
1578        if let Some((cache, request, key)) = cache_key {
1579            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1580        }
1581
1582        Ok((response, usage))
1583    }
1584
1585    fn handle_api_response<T>(
1586        this: &WeakEntity<Self>,
1587        response: Result<(T, Option<EditPredictionUsage>)>,
1588        cx: &mut gpui::AsyncApp,
1589    ) -> Result<T> {
1590        match response {
1591            Ok((data, usage)) => {
1592                if let Some(usage) = usage {
1593                    this.update(cx, |this, cx| {
1594                        this.user_store.update(cx, |user_store, cx| {
1595                            user_store.update_edit_prediction_usage(usage, cx);
1596                        });
1597                    })
1598                    .ok();
1599                }
1600                Ok(data)
1601            }
1602            Err(err) => {
1603                if err.is::<ZedUpdateRequiredError>() {
1604                    cx.update(|cx| {
1605                        this.update(cx, |this, _cx| {
1606                            this.update_required = true;
1607                        })
1608                        .ok();
1609
1610                        let error_message: SharedString = err.to_string().into();
1611                        show_app_notification(
1612                            NotificationId::unique::<ZedUpdateRequiredError>(),
1613                            cx,
1614                            move |cx| {
1615                                cx.new(|cx| {
1616                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1617                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1618                                })
1619                            },
1620                        );
1621                    })
1622                    .ok();
1623                }
1624                Err(err)
1625            }
1626        }
1627    }
1628
1629    async fn send_api_request<Res>(
1630        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1631        client: Arc<Client>,
1632        llm_token: LlmApiToken,
1633        app_version: Version,
1634    ) -> Result<(Res, Option<EditPredictionUsage>)>
1635    where
1636        Res: DeserializeOwned,
1637    {
1638        let http_client = client.http_client();
1639        let mut token = llm_token.acquire(&client).await?;
1640        let mut did_retry = false;
1641
1642        loop {
1643            let request_builder = http_client::Request::builder().method(Method::POST);
1644
1645            let request = build(
1646                request_builder
1647                    .header("Content-Type", "application/json")
1648                    .header("Authorization", format!("Bearer {}", token))
1649                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
1650            )?;
1651
1652            let mut response = http_client.send(request).await?;
1653
1654            if let Some(minimum_required_version) = response
1655                .headers()
1656                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1657                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1658            {
1659                anyhow::ensure!(
1660                    app_version >= minimum_required_version,
1661                    ZedUpdateRequiredError {
1662                        minimum_version: minimum_required_version
1663                    }
1664                );
1665            }
1666
1667            if response.status().is_success() {
1668                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1669
1670                let mut body = Vec::new();
1671                response.body_mut().read_to_end(&mut body).await?;
1672                return Ok((serde_json::from_slice(&body)?, usage));
1673            } else if !did_retry
1674                && response
1675                    .headers()
1676                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1677                    .is_some()
1678            {
1679                did_retry = true;
1680                token = llm_token.refresh(&client).await?;
1681            } else {
1682                let mut body = String::new();
1683                response.body_mut().read_to_string(&mut body).await?;
1684                anyhow::bail!(
1685                    "Request failed with status: {:?}\nBody: {}",
1686                    response.status(),
1687                    body
1688                );
1689            }
1690        }
1691    }
1692
1693    pub fn refresh_context(
1694        &mut self,
1695        project: &Entity<Project>,
1696        buffer: &Entity<language::Buffer>,
1697        cursor_position: language::Anchor,
1698        cx: &mut Context<Self>,
1699    ) {
1700        if self.use_context {
1701            self.get_or_init_project(project, cx)
1702                .context
1703                .update(cx, |store, cx| {
1704                    store.refresh(buffer.clone(), cursor_position, cx);
1705                });
1706        }
1707    }
1708
1709    fn is_file_open_source(
1710        &self,
1711        project: &Entity<Project>,
1712        file: &Arc<dyn File>,
1713        cx: &App,
1714    ) -> bool {
1715        if !file.is_local() || file.is_private() {
1716            return false;
1717        }
1718        let Some(project_state) = self.projects.get(&project.entity_id()) else {
1719            return false;
1720        };
1721        project_state
1722            .license_detection_watchers
1723            .get(&file.worktree_id(cx))
1724            .as_ref()
1725            .is_some_and(|watcher| watcher.is_project_open_source())
1726    }
1727
1728    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1729        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1730    }
1731
1732    fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
1733        if !self.data_collection_choice.is_enabled() {
1734            return false;
1735        }
1736        events.iter().all(|event| {
1737            matches!(
1738                event.as_ref(),
1739                Event::BufferChange {
1740                    in_open_source_repo: true,
1741                    ..
1742                }
1743            )
1744        })
1745    }
1746
1747    fn load_data_collection_choice() -> DataCollectionChoice {
1748        let choice = KEY_VALUE_STORE
1749            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1750            .log_err()
1751            .flatten();
1752
1753        match choice.as_deref() {
1754            Some("true") => DataCollectionChoice::Enabled,
1755            Some("false") => DataCollectionChoice::Disabled,
1756            Some(_) => {
1757                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1758                DataCollectionChoice::NotAnswered
1759            }
1760            None => DataCollectionChoice::NotAnswered,
1761        }
1762    }
1763
1764    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1765        self.data_collection_choice = self.data_collection_choice.toggle();
1766        let new_choice = self.data_collection_choice;
1767        db::write_and_log(cx, move || {
1768            KEY_VALUE_STORE.write_kvp(
1769                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1770                new_choice.is_enabled().to_string(),
1771            )
1772        });
1773    }
1774
1775    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1776        self.shown_predictions.iter()
1777    }
1778
1779    pub fn shown_completions_len(&self) -> usize {
1780        self.shown_predictions.len()
1781    }
1782
1783    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
1784        self.rated_predictions.contains(id)
1785    }
1786
1787    pub fn rate_prediction(
1788        &mut self,
1789        prediction: &EditPrediction,
1790        rating: EditPredictionRating,
1791        feedback: String,
1792        cx: &mut Context<Self>,
1793    ) {
1794        self.rated_predictions.insert(prediction.id.clone());
1795        telemetry::event!(
1796            "Edit Prediction Rated",
1797            rating,
1798            inputs = prediction.inputs,
1799            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
1800            feedback
1801        );
1802        self.client.telemetry().flush_events().detach();
1803        cx.notify();
1804    }
1805
1806    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
1807        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
1808            && all_language_settings(None, cx).edit_predictions.use_context;
1809    }
1810}
1811
1812#[derive(Error, Debug)]
1813#[error(
1814    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1815)]
1816pub struct ZedUpdateRequiredError {
1817    minimum_version: Version,
1818}
1819
/// Key of an eval-cache entry: the kind of request paired with a 64-bit hash.
///
/// The key is `Eq + Hash`, so it can be used directly in hash-based collections.
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

/// Category of a cached eval entry; forms the first half of an [`EvalCacheKey`].
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}

/// Formats the kind as its lowercase name (`context`, `search`, `prediction`).
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(label)
    }
}

/// Storage backend for cached eval requests and responses.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the value stored under `key`, if any.
    fn read(&self, key: EvalCacheKey) -> Option<String>;

    /// Stores `value` under `key`; `input` is the input that produced `value`.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
1847
/// The user's answer to whether edit prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been recorded.
    NotAnswered,
    /// Data collection was opted into.
    Enabled,
    /// Data collection was declined.
    Disabled,
}

impl DataCollectionChoice {
    /// `true` only for [`DataCollectionChoice::Enabled`].
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// `true` once the user has picked either option.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice; an unanswered choice toggles to `Enabled`.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match *self {
            Self::Enabled => Self::Disabled,
            Self::NotAnswered | Self::Disabled => Self::Enabled,
        }
    }
}

/// Maps `true` to `Enabled` and `false` to `Disabled`.
impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
1888
1889struct ZedPredictUpsell;
1890
1891impl Dismissable for ZedPredictUpsell {
1892    const KEY: &'static str = "dismissed-edit-predict-upsell";
1893
1894    fn dismissed() -> bool {
1895        // To make this backwards compatible with older versions of Zed, we
1896        // check if the user has seen the previous Edit Prediction Onboarding
1897        // before, by checking the data collection choice which was written to
1898        // the database once the user clicked on "Accept and Enable"
1899        if KEY_VALUE_STORE
1900            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1901            .log_err()
1902            .is_some_and(|s| s.is_some())
1903        {
1904            return true;
1905        }
1906
1907        KEY_VALUE_STORE
1908            .read_kvp(Self::KEY)
1909            .log_err()
1910            .is_some_and(|s| s.is_some())
1911    }
1912}
1913
1914pub fn should_show_upsell_modal() -> bool {
1915    !ZedPredictUpsell::dismissed()
1916}
1917
1918pub fn init(cx: &mut App) {
1919    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
1920        workspace.register_action(
1921            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
1922                ZedPredictModal::toggle(
1923                    workspace,
1924                    workspace.user_store().clone(),
1925                    workspace.client().clone(),
1926                    window,
1927                    cx,
1928                )
1929            },
1930        );
1931
1932        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
1933            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
1934                settings
1935                    .project
1936                    .all_languages
1937                    .features
1938                    .get_or_insert_default()
1939                    .edit_prediction_provider = Some(EditPredictionProvider::None)
1940            });
1941        });
1942    })
1943    .detach();
1944}