edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
   7    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
   8    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
   9    ZED_VERSION_HEADER_NAME,
  10};
  11use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
  12use collections::{HashMap, HashSet};
  13use db::kvp::{Dismissable, KEY_VALUE_STORE};
  14use edit_prediction_context::EditPredictionExcerptOptions;
  15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::{
  18    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  19    channel::{
  20        mpsc::{self, UnboundedReceiver},
  21        oneshot,
  22    },
  23    select_biased,
  24};
  25use gpui::BackgroundExecutor;
  26use gpui::{
  27    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  28    http_client::{self, AsyncBody, Method},
  29    prelude::*,
  30};
  31use language::language_settings::all_language_settings;
  32use language::{Anchor, Buffer, File, Point, ToPoint};
  33use language::{BufferSnapshot, OffsetRangeExt};
  34use language_model::{LlmApiToken, RefreshLlmTokenListener};
  35use project::{Project, ProjectPath, WorktreeId};
  36use release_channel::AppVersion;
  37use semver::Version;
  38use serde::de::DeserializeOwned;
  39use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
  40use std::collections::{VecDeque, hash_map};
  41use workspace::Workspace;
  42
  43use std::ops::Range;
  44use std::path::Path;
  45use std::rc::Rc;
  46use std::str::FromStr as _;
  47use std::sync::{Arc, LazyLock};
  48use std::time::{Duration, Instant};
  49use std::{env, mem};
  50use thiserror::Error;
  51use util::{RangeExt as _, ResultExt as _};
  52use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  53
  54mod license_detection;
  55mod onboarding_modal;
  56mod prediction;
  57pub mod sweep_ai;
  58pub mod udiff;
  59mod xml_edits;
  60mod zed_edit_prediction_delegate;
  61pub mod zeta1;
  62pub mod zeta2;
  63
  64#[cfg(test)]
  65mod edit_prediction_tests;
  66
  67use crate::license_detection::LicenseDetectionWatcher;
  68use crate::onboarding_modal::ZedPredictModal;
  69pub use crate::prediction::EditPrediction;
  70pub use crate::prediction::EditPredictionId;
  71pub use crate::prediction::EditPredictionInputs;
  72use crate::prediction::EditPredictionResult;
  73pub use crate::sweep_ai::SweepAi;
  74pub use telemetry_events::EditPredictionRating;
  75pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  76
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  86
  87/// Maximum number of events to track.
  88const EVENT_COUNT_MAX: usize = 6;
  89const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
  90const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
  91const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
  92
  93pub struct SweepFeatureFlag;
  94
  95impl FeatureFlag for SweepFeatureFlag {
  96    const NAME: &str = "sweep-ai";
  97}
  98
  99pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
 100    context: EditPredictionExcerptOptions {
 101        max_bytes: 512,
 102        min_bytes: 128,
 103        target_before_cursor_over_total_bytes: 0.5,
 104    },
 105    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
 106    prompt_format: PromptFormat::DEFAULT,
 107};
 108
 109static USE_OLLAMA: LazyLock<bool> =
 110    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 111
 112static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 113    match env::var("ZED_ZETA2_MODEL").as_deref() {
 114        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 115        Ok(model) => model,
 116        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 117        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 118    }
 119    .to_string()
 120});
 121static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 122    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 123        if *USE_OLLAMA {
 124            Some("http://localhost:11434/v1/chat/completions".into())
 125        } else {
 126            None
 127        }
 128    })
 129});
 130
 131pub struct Zeta2FeatureFlag;
 132
 133impl FeatureFlag for Zeta2FeatureFlag {
 134    const NAME: &'static str = "zeta2";
 135
 136    fn enabled_for_staff() -> bool {
 137        true
 138    }
 139}
 140
/// App-global handle to the shared [`EditPredictionStore`] entity; installed
/// by `EditPredictionStore::global` and read by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 145
/// Central state for edit predictions, shared app-wide through
/// [`EditPredictionStoreGlobal`].
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token for LLM requests; refreshed whenever `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// Set via `set_use_context`; starts out `false`.
    use_context: bool,
    /// Prompt/context options; starts out as [`DEFAULT_OPTIONS`].
    options: ZetaOptions,
    /// Starts out `false`. NOTE(review): presumably set when the server
    /// requires a newer client version — confirm.
    update_required: bool,
    /// Sender for [`DebugEvent`]s; installed by `debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    /// Active prediction backend; `new` starts with `Zeta2`.
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    /// Loaded via `load_data_collection_choice` when the store is created.
    data_collection_choice: DataCollectionChoice,
    /// Feeds the background rejection-reporting task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    /// NOTE(review): presumably predictions that were displayed — confirm.
    shown_predictions: VecDeque<EditPrediction>,
    /// NOTE(review): presumably ids of predictions that were rated — confirm.
    rated_predictions: HashSet<EditPredictionId>,
}
 165
/// Which backend produces edit predictions.
///
/// Note: `Default` yields `Zeta1`, whereas `EditPredictionStore::new` starts
/// with `Zeta2`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}
 173
/// Tunable options for building prediction requests; see [`DEFAULT_OPTIONS`].
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Size limits for the excerpt taken around the cursor.
    pub context: EditPredictionExcerptOptions,
    /// Upper bound on the prompt size, in bytes.
    pub max_prompt_bytes: usize,
    /// Prompt format to render.
    pub prompt_format: predict_edits_v3::PromptFormat,
}
 180
/// Diagnostic events delivered to the receiver returned by
/// `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionRequested(EditPredictionRequestedDebugEvent),
}
 187
/// Emitted when a project's related-excerpt store starts refreshing.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Left empty by the emitter in `get_or_init_project`.
    pub search_prompt: String,
}
 194
/// Emitted when a project's related-excerpt store finishes refreshing.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Label/value pairs for display (e.g. cache hits, LSP latencies).
    pub metadata: Vec<(&'static str, SharedString)>,
}
 201
/// Details of a prediction request, for debugging.
///
/// NOTE(review): the emitter is not in this part of the file; the field notes
/// below are inferred from names and types — confirm against the sender.
#[derive(Debug)]
pub struct EditPredictionRequestedDebugEvent {
    pub inputs: EditPredictionInputs,
    /// Presumably how long context retrieval took.
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Presumably the locally rendered prompt, or an error message.
    pub local_prompt: Result<String, String>,
    /// Receives the model response (or an error message) paired with a duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}
 211
/// Alias for the `predict_edits_v3` debug info type.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 213
/// Edit-prediction state tracked for one project.
struct ProjectState {
    /// Finalized buffer-change events, oldest first.
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    /// Pending event that later nearby edits may still be coalesced into.
    last_event: Option<LastEvent>,
    /// Recently activated paths, most recent first, without duplicates.
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers being tracked, keyed by buffer entity id; removed on release.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// NOTE(review): presumably the next `PendingPrediction::id` to hand out — confirm.
    next_pending_prediction_id: usize,
    /// At most two in-flight prediction requests.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Signalled each time the related-excerpt store finishes a refresh.
    context_updates_tx: smol::channel::Sender<()>,
    /// Handed out by `project_context_updates`.
    context_updates_rx: smol::channel::Receiver<()>,
    /// NOTE(review): presumably the buffer and time of the last refresh — confirm.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    /// Store of related excerpts used as prompt context.
    context: Entity<RelatedExcerptStore>,
    /// One watcher per worktree; removed when the worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Subscription to the project's events.
    _subscription: gpui::Subscription,
}
 230
 231impl ProjectState {
 232    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
 233        self.events
 234            .iter()
 235            .cloned()
 236            .chain(
 237                self.last_event
 238                    .as_ref()
 239                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 240            )
 241            .collect()
 242    }
 243
 244    fn cancel_pending_prediction(
 245        &mut self,
 246        pending_prediction: PendingPrediction,
 247        cx: &mut Context<EditPredictionStore>,
 248    ) {
 249        self.cancelled_predictions.insert(pending_prediction.id);
 250
 251        cx.spawn(async move |this, cx| {
 252            let Some(prediction_id) = pending_prediction.task.await else {
 253                return;
 254            };
 255
 256            this.update(cx, |this, _cx| {
 257                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 258            })
 259            .ok();
 260        })
 261        .detach()
 262    }
 263}
 264
/// The prediction currently offered for a project, with what triggered it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// NOTE(review): presumably set once the prediction was displayed — confirm.
    pub was_shown: bool,
}
 271
 272impl CurrentEditPrediction {
 273    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 274        let Some(new_edits) = self
 275            .prediction
 276            .interpolate(&self.prediction.buffer.read(cx))
 277        else {
 278            return false;
 279        };
 280
 281        if self.prediction.buffer != old_prediction.prediction.buffer {
 282            return true;
 283        }
 284
 285        let Some(old_edits) = old_prediction
 286            .prediction
 287            .interpolate(&old_prediction.prediction.buffer.read(cx))
 288        else {
 289            return true;
 290        };
 291
 292        let requested_by_buffer_id = self.requested_by.buffer_id();
 293
 294        // This reduces the occurrence of UI thrash from replacing edits
 295        //
 296        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 297        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 298            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 299            && old_edits.len() == 1
 300            && new_edits.len() == 1
 301        {
 302            let (old_range, old_text) = &old_edits[0];
 303            let (new_range, new_text) = &new_edits[0];
 304            new_range == old_range && new_text.starts_with(old_text.as_ref())
 305        } else {
 306            true
 307        }
 308    }
 309}
 310
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// A diagnostics update in the project.
    DiagnosticsUpdate,
    /// A request made on behalf of the buffer with this entity id.
    Buffer(EntityId),
}
 316
 317impl PredictionRequestedBy {
 318    pub fn buffer_id(&self) -> Option<EntityId> {
 319        match self {
 320            PredictionRequestedBy::DiagnosticsUpdate => None,
 321            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 322        }
 323    }
 324}
 325
/// A prediction request that has been started but not yet resolved.
#[derive(Debug)]
struct PendingPrediction {
    /// Recorded in `ProjectState::cancelled_predictions` when cancelled.
    id: usize,
    /// Resolves to the resulting prediction's id (presumably `None` when no
    /// prediction was produced).
    task: Task<Option<EditPredictionId>>,
}
 331
/// A prediction from the perspective of a buffer.
///
/// `Local`: the prediction targets this buffer. `Jump`: it targets a different
/// buffer (see `current_prediction_for_buffer`).
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
 338
 339#[cfg(test)]
 340impl std::ops::Deref for BufferEditPrediction<'_> {
 341    type Target = EditPrediction;
 342
 343    fn deref(&self) -> &Self::Target {
 344        match self {
 345            BufferEditPrediction::Local { prediction } => prediction,
 346            BufferEditPrediction::Jump { prediction } => prediction,
 347        }
 348    }
 349}
 350
/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; the next change is diffed against it.
    snapshot: BufferSnapshot,
    /// Edit and release subscriptions, held here instead of being detached.
    _subscriptions: [gpui::Subscription; 2],
}
 355
/// An edit not yet moved into `ProjectState::events`; nearby follow-up edits to
/// the same buffer are coalesced into it.
struct LastEvent {
    /// Buffer state before the first coalesced edit.
    old_snapshot: BufferSnapshot,
    /// Buffer state after the most recent coalesced edit.
    new_snapshot: BufferSnapshot,
    /// End of the most recent edit; compared against the next edit to decide coalescing.
    end_edit_anchor: Option<Anchor>,
}
 361
 362impl LastEvent {
 363    pub fn finalize(
 364        &self,
 365        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 366        cx: &App,
 367    ) -> Option<Arc<predict_edits_v3::Event>> {
 368        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
 369        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
 370
 371        let file = self.new_snapshot.file();
 372        let old_file = self.old_snapshot.file();
 373
 374        let in_open_source_repo = [file, old_file].iter().all(|file| {
 375            file.is_some_and(|file| {
 376                license_detection_watchers
 377                    .get(&file.worktree_id(cx))
 378                    .is_some_and(|watcher| watcher.is_project_open_source())
 379            })
 380        });
 381
 382        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
 383
 384        if path == old_path && diff.is_empty() {
 385            None
 386        } else {
 387            Some(Arc::new(predict_edits_v3::Event::BufferChange {
 388                old_path,
 389                path,
 390                diff,
 391                in_open_source_repo,
 392                // TODO: Actually detect if this edit was predicted or not
 393                predicted: false,
 394            }))
 395        }
 396    }
 397}
 398
 399fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
 400    if let Some(file) = snapshot.file() {
 401        file.full_path(cx).into()
 402    } else {
 403        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 404    }
 405}
 406
 407impl EditPredictionStore {
 408    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 409        cx.try_global::<EditPredictionStoreGlobal>()
 410            .map(|global| global.0.clone())
 411    }
 412
 413    pub fn global(
 414        client: &Arc<Client>,
 415        user_store: &Entity<UserStore>,
 416        cx: &mut App,
 417    ) -> Entity<Self> {
 418        cx.try_global::<EditPredictionStoreGlobal>()
 419            .map(|global| global.0.clone())
 420            .unwrap_or_else(|| {
 421                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 422                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 423                ep_store
 424            })
 425    }
 426
    /// Creates the store.
    ///
    /// Besides initializing every field, this:
    /// - spawns a detached background task that receives rejections sent on
    ///   `reject_predictions_tx` and hands them to `handle_rejected_predictions`;
    /// - refreshes `llm_token` whenever [`RefreshLlmTokenListener`] emits;
    /// - calls `configure_context_retrieval` now, again once feature flags are
    ///   ready, and on every [`SettingsStore`] change.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections sent on this channel are processed by the detached
        // background task below.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    // Refresh asynchronously; failures are logged, not surfaced.
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            // Note: differs from `EditPredictionModel::default()` (`Zeta1`).
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        };

        this.configure_context_retrieval(cx);
        // Feature flags may not be known yet; reconfigure once they are.
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
 498
 499    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 500        self.edit_prediction_model = model;
 501    }
 502
 503    pub fn has_sweep_api_token(&self) -> bool {
 504        self.sweep_ai
 505            .api_token
 506            .clone()
 507            .now_or_never()
 508            .flatten()
 509            .is_some()
 510    }
 511
 512    #[cfg(feature = "eval-support")]
 513    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
 514        self.eval_cache = Some(cache);
 515    }
 516
 517    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<DebugEvent> {
 518        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 519        self.debug_tx = Some(debug_watch_tx);
 520        debug_watch_rx
 521    }
 522
 523    pub fn options(&self) -> &ZetaOptions {
 524        &self.options
 525    }
 526
 527    pub fn set_options(&mut self, options: ZetaOptions) {
 528        self.options = options;
 529    }
 530
 531    pub fn set_use_context(&mut self, use_context: bool) {
 532        self.use_context = use_context;
 533    }
 534
 535    pub fn clear_history(&mut self) {
 536        for project_state in self.projects.values_mut() {
 537            project_state.events.clear();
 538        }
 539    }
 540
 541    pub fn context_for_project<'a>(
 542        &'a self,
 543        project: &Entity<Project>,
 544        cx: &'a App,
 545    ) -> &'a [RelatedFile] {
 546        self.projects
 547            .get(&project.entity_id())
 548            .map(|project| project.context.read(cx).related_files())
 549            .unwrap_or(&[])
 550    }
 551
 552    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 553        if self.edit_prediction_model == EditPredictionModel::Zeta2 {
 554            self.user_store.read(cx).edit_prediction_usage()
 555        } else {
 556            None
 557        }
 558    }
 559
 560    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 561        self.get_or_init_project(project, cx);
 562    }
 563
 564    pub fn register_buffer(
 565        &mut self,
 566        buffer: &Entity<Buffer>,
 567        project: &Entity<Project>,
 568        cx: &mut Context<Self>,
 569    ) {
 570        let project_state = self.get_or_init_project(project, cx);
 571        Self::register_buffer_impl(project_state, buffer, project, cx);
 572    }
 573
 574    fn get_or_init_project(
 575        &mut self,
 576        project: &Entity<Project>,
 577        cx: &mut Context<Self>,
 578    ) -> &mut ProjectState {
 579        let entity_id = project.entity_id();
 580        let (context_updates_tx, context_updates_rx) = smol::channel::unbounded();
 581        self.projects
 582            .entry(entity_id)
 583            .or_insert_with(|| ProjectState {
 584                context: {
 585                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 586                    cx.subscribe(
 587                        &related_excerpt_store,
 588                        move |this, _, event, _| match event {
 589                            RelatedExcerptStoreEvent::StartedRefresh => {
 590                                if let Some(debug_tx) = this.debug_tx.clone() {
 591                                    debug_tx
 592                                        .unbounded_send(DebugEvent::ContextRetrievalStarted(
 593                                            ContextRetrievalStartedDebugEvent {
 594                                                project_entity_id: entity_id,
 595                                                timestamp: Instant::now(),
 596                                                search_prompt: String::new(),
 597                                            },
 598                                        ))
 599                                        .ok();
 600                                }
 601                            }
 602                            RelatedExcerptStoreEvent::FinishedRefresh {
 603                                cache_hit_count,
 604                                cache_miss_count,
 605                                mean_definition_latency,
 606                                max_definition_latency,
 607                            } => {
 608                                if let Some(debug_tx) = this.debug_tx.clone() {
 609                                    debug_tx
 610                                        .unbounded_send(DebugEvent::ContextRetrievalFinished(
 611                                            ContextRetrievalFinishedDebugEvent {
 612                                                project_entity_id: entity_id,
 613                                                timestamp: Instant::now(),
 614                                                metadata: vec![
 615                                                    (
 616                                                        "Cache Hits",
 617                                                        format!(
 618                                                            "{}/{}",
 619                                                            cache_hit_count,
 620                                                            cache_hit_count + cache_miss_count
 621                                                        )
 622                                                        .into(),
 623                                                    ),
 624                                                    (
 625                                                        "Max LSP Time",
 626                                                        format!(
 627                                                            "{} ms",
 628                                                            max_definition_latency.as_millis()
 629                                                        )
 630                                                        .into(),
 631                                                    ),
 632                                                    (
 633                                                        "Mean LSP Time",
 634                                                        format!(
 635                                                            "{} ms",
 636                                                            mean_definition_latency.as_millis()
 637                                                        )
 638                                                        .into(),
 639                                                    ),
 640                                                ],
 641                                            },
 642                                        ))
 643                                        .ok();
 644                                }
 645                                if let Some(project_state) = this.projects.get(&entity_id) {
 646                                    project_state.context_updates_tx.send_blocking(()).ok();
 647                                }
 648                            }
 649                        },
 650                    )
 651                    .detach();
 652                    related_excerpt_store
 653                },
 654                events: VecDeque::new(),
 655                last_event: None,
 656                recent_paths: VecDeque::new(),
 657                context_updates_rx,
 658                context_updates_tx,
 659                registered_buffers: HashMap::default(),
 660                current_prediction: None,
 661                cancelled_predictions: HashSet::default(),
 662                pending_predictions: ArrayVec::new(),
 663                next_pending_prediction_id: 0,
 664                last_prediction_refresh: None,
 665                license_detection_watchers: HashMap::default(),
 666                _subscription: cx.subscribe(&project, Self::handle_project_event),
 667            })
 668    }
 669
 670    pub fn project_context_updates(
 671        &self,
 672        project: &Entity<Project>,
 673    ) -> Option<smol::channel::Receiver<()>> {
 674        let project_state = self.projects.get(&project.entity_id())?;
 675        Some(project_state.context_updates_rx.clone())
 676    }
 677
 678    fn handle_project_event(
 679        &mut self,
 680        project: Entity<Project>,
 681        event: &project::Event,
 682        cx: &mut Context<Self>,
 683    ) {
 684        // TODO [zeta2] init with recent paths
 685        match event {
 686            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 687                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 688                    return;
 689                };
 690                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 691                if let Some(path) = path {
 692                    if let Some(ix) = project_state
 693                        .recent_paths
 694                        .iter()
 695                        .position(|probe| probe == &path)
 696                    {
 697                        project_state.recent_paths.remove(ix);
 698                    }
 699                    project_state.recent_paths.push_front(path);
 700                }
 701            }
 702            project::Event::DiagnosticsUpdated { .. } => {
 703                if cx.has_flag::<Zeta2FeatureFlag>() {
 704                    self.refresh_prediction_from_diagnostics(project, cx);
 705                }
 706            }
 707            _ => (),
 708        }
 709    }
 710
    /// Ensures `buffer` is registered in `project_state` and returns its entry.
    ///
    /// If the buffer's file belongs to a worktree of `project`, a
    /// [`LicenseDetectionWatcher`] is created for that worktree the first time
    /// it is seen. The watcher is removed again when the worktree is released.
    ///
    /// A newly registered buffer stores its current snapshot and holds two
    /// subscriptions:
    /// - `Edited` events are forwarded to `report_changes_for_buffer`;
    /// - releasing the buffer removes it from `registered_buffers`.
    ///
    /// An already-registered entry is returned unchanged; its snapshot is not
    /// refreshed here.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher when its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Weak handle: the subscription must not keep the project alive.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 772
    /// Records an edit event for `buffer` if its version changed since it was last registered.
    ///
    /// The new snapshot is compared with the one stored for the buffer. If the change continues
    /// the project's pending `last_event` (same buffer, starting from that event's new version)
    /// and its last edit ends within `CHANGE_GROUPING_LINE_SPAN` rows of the previous one, it is
    /// coalesced into that pending event. Otherwise the pending event is finalized into `events`
    /// and this change becomes the new pending event.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Nothing changed since the last snapshot recorded for this buffer.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // End of the last edit between the two versions (`None` if the edit iterator is empty).
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // The change starts exactly where the pending event left off.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // Group edits whose end rows (resolved in the new snapshot) are close together.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // Extend the pending event in place instead of recording a new one.
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Evict the oldest event to keep the history bounded by EVENT_COUNT_MAX.
        // NOTE(review): this eviction also runs when there is no pending `last_event` to
        // finalize below, so an event may be dropped without a replacement being added.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        // Finalize the previous pending event (if any) into the event history.
        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
 835
 836    fn current_prediction_for_buffer(
 837        &self,
 838        buffer: &Entity<Buffer>,
 839        project: &Entity<Project>,
 840        cx: &App,
 841    ) -> Option<BufferEditPrediction<'_>> {
 842        let project_state = self.projects.get(&project.entity_id())?;
 843
 844        let CurrentEditPrediction {
 845            requested_by,
 846            prediction,
 847            ..
 848        } = project_state.current_prediction.as_ref()?;
 849
 850        if prediction.targets_buffer(buffer.read(cx)) {
 851            Some(BufferEditPrediction::Local { prediction })
 852        } else {
 853            let show_jump = match requested_by {
 854                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
 855                    requested_by_buffer_id == &buffer.entity_id()
 856                }
 857                PredictionRequestedBy::DiagnosticsUpdate => true,
 858            };
 859
 860            if show_jump {
 861                Some(BufferEditPrediction::Jump { prediction })
 862            } else {
 863                None
 864            }
 865        }
 866    }
 867
 868    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 869        match self.edit_prediction_model {
 870            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
 871            EditPredictionModel::Sweep => return,
 872        }
 873
 874        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 875            return;
 876        };
 877
 878        let Some(prediction) = project_state.current_prediction.take() else {
 879            return;
 880        };
 881        let request_id = prediction.prediction.id.to_string();
 882        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
 883            project_state.cancel_pending_prediction(pending_prediction, cx);
 884        }
 885
 886        let client = self.client.clone();
 887        let llm_token = self.llm_token.clone();
 888        let app_version = AppVersion::global(cx);
 889        cx.spawn(async move |this, cx| {
 890            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
 891                http_client::Url::parse(&predict_edits_url)?
 892            } else {
 893                client
 894                    .http_client()
 895                    .build_zed_llm_url("/predict_edits/accept", &[])?
 896            };
 897
 898            let response = cx
 899                .background_spawn(Self::send_api_request::<()>(
 900                    move |builder| {
 901                        let req = builder.uri(url.as_ref()).body(
 902                            serde_json::to_string(&AcceptEditPredictionBody {
 903                                request_id: request_id.clone(),
 904                            })?
 905                            .into(),
 906                        );
 907                        Ok(req?)
 908                    },
 909                    client,
 910                    llm_token,
 911                    app_version,
 912                ))
 913                .await;
 914
 915            Self::handle_api_response(&this, response, cx)?;
 916            anyhow::Ok(())
 917        })
 918        .detach_and_log_err(cx);
 919    }
 920
 921    async fn handle_rejected_predictions(
 922        rx: UnboundedReceiver<EditPredictionRejection>,
 923        client: Arc<Client>,
 924        llm_token: LlmApiToken,
 925        app_version: Version,
 926        background_executor: BackgroundExecutor,
 927    ) {
 928        let mut rx = std::pin::pin!(rx.peekable());
 929        let mut batched = Vec::new();
 930
 931        while let Some(rejection) = rx.next().await {
 932            batched.push(rejection);
 933
 934            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
 935                select_biased! {
 936                    next = rx.as_mut().peek().fuse() => {
 937                        if next.is_some() {
 938                            continue;
 939                        }
 940                    }
 941                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
 942                }
 943            }
 944
 945            let url = client
 946                .http_client()
 947                .build_zed_llm_url("/predict_edits/reject", &[])
 948                .unwrap();
 949
 950            let flush_count = batched
 951                .len()
 952                // in case items have accumulated after failure
 953                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
 954            let start = batched.len() - flush_count;
 955
 956            let body = RejectEditPredictionsBodyRef {
 957                rejections: &batched[start..],
 958            };
 959
 960            let result = Self::send_api_request::<()>(
 961                |builder| {
 962                    let req = builder
 963                        .uri(url.as_ref())
 964                        .body(serde_json::to_string(&body)?.into());
 965                    anyhow::Ok(req?)
 966                },
 967                client.clone(),
 968                llm_token.clone(),
 969                app_version.clone(),
 970            )
 971            .await;
 972
 973            if result.log_err().is_some() {
 974                batched.drain(start..);
 975            }
 976        }
 977    }
 978
 979    fn reject_current_prediction(
 980        &mut self,
 981        reason: EditPredictionRejectReason,
 982        project: &Entity<Project>,
 983    ) {
 984        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 985            project_state.pending_predictions.clear();
 986            if let Some(prediction) = project_state.current_prediction.take() {
 987                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
 988            }
 989        };
 990    }
 991
 992    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
 993        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 994            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
 995                if !current_prediction.was_shown {
 996                    current_prediction.was_shown = true;
 997                    self.shown_predictions
 998                        .push_front(current_prediction.prediction.clone());
 999                    if self.shown_predictions.len() > 50 {
1000                        let completion = self.shown_predictions.pop_back().unwrap();
1001                        self.rated_predictions.remove(&completion.id);
1002                    }
1003                }
1004            }
1005        }
1006    }
1007
1008    fn reject_prediction(
1009        &mut self,
1010        prediction_id: EditPredictionId,
1011        reason: EditPredictionRejectReason,
1012        was_shown: bool,
1013    ) {
1014        match self.edit_prediction_model {
1015            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
1016            EditPredictionModel::Sweep => return,
1017        }
1018
1019        self.reject_predictions_tx
1020            .unbounded_send(EditPredictionRejection {
1021                request_id: prediction_id.to_string(),
1022                reason,
1023                was_shown,
1024            })
1025            .log_err();
1026    }
1027
1028    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1029        self.projects
1030            .get(&project.entity_id())
1031            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1032    }
1033
1034    pub fn refresh_prediction_from_buffer(
1035        &mut self,
1036        project: Entity<Project>,
1037        buffer: Entity<Buffer>,
1038        position: language::Anchor,
1039        cx: &mut Context<Self>,
1040    ) {
1041        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1042            let Some(request_task) = this
1043                .update(cx, |this, cx| {
1044                    this.request_prediction(
1045                        &project,
1046                        &buffer,
1047                        position,
1048                        PredictEditsRequestTrigger::Other,
1049                        cx,
1050                    )
1051                })
1052                .log_err()
1053            else {
1054                return Task::ready(anyhow::Ok(None));
1055            };
1056
1057            cx.spawn(async move |_cx| {
1058                request_task.await.map(|prediction_result| {
1059                    prediction_result.map(|prediction_result| {
1060                        (
1061                            prediction_result,
1062                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1063                        )
1064                    })
1065                })
1066            })
1067        })
1068    }
1069
    /// Queues a prediction refresh at a diagnostic reachable from the project's active entry.
    ///
    /// Does nothing if the project is not registered or already has a current prediction.
    /// Otherwise the buffer for the project's active entry is opened, a diagnostic location is
    /// chosen with `next_diagnostic_location`, and a prediction is requested there with the
    /// `Diagnostics` trigger. If a current prediction appeared while the request was in flight,
    /// the new result is turned into a `CurrentPreferred` rejection instead.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        // Throttled per project (the project's own entity id is the throttle key).
        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // No active entry (or no path for it) means there is nothing to search from.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Default search range and cursor: no region of the active buffer is excluded
                // from the search; distances are measured from the default `Point`.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        // A buffer-driven prediction that arrived in the meantime wins; report
                        // this one as rejected in favor of the current prediction.
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1150
    /// Minimum delay between two prediction refreshes for the same throttle entity. A refresh
    /// queued sooner than this waits out the remainder before running.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Throttling is disabled under test so queued refreshes run immediately.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1155
1156    fn queue_prediction_refresh(
1157        &mut self,
1158        project: Entity<Project>,
1159        throttle_entity: EntityId,
1160        cx: &mut Context<Self>,
1161        do_refresh: impl FnOnce(
1162            WeakEntity<Self>,
1163            &mut AsyncApp,
1164        )
1165            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1166        + 'static,
1167    ) {
1168        let project_state = self.get_or_init_project(&project, cx);
1169        let pending_prediction_id = project_state.next_pending_prediction_id;
1170        project_state.next_pending_prediction_id += 1;
1171        let last_request = project_state.last_prediction_refresh;
1172
1173        let task = cx.spawn(async move |this, cx| {
1174            if let Some((last_entity, last_timestamp)) = last_request
1175                && throttle_entity == last_entity
1176                && let Some(timeout) =
1177                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1178            {
1179                cx.background_executor().timer(timeout).await;
1180            }
1181
1182            // If this task was cancelled before the throttle timeout expired,
1183            // do not perform a request.
1184            let mut is_cancelled = true;
1185            this.update(cx, |this, cx| {
1186                let project_state = this.get_or_init_project(&project, cx);
1187                if !project_state
1188                    .cancelled_predictions
1189                    .remove(&pending_prediction_id)
1190                {
1191                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1192                    is_cancelled = false;
1193                }
1194            })
1195            .ok();
1196            if is_cancelled {
1197                return None;
1198            }
1199
1200            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1201            let new_prediction_id = new_prediction_result
1202                .as_ref()
1203                .map(|(prediction, _)| prediction.id.clone());
1204
1205            // When a prediction completes, remove it from the pending list, and cancel
1206            // any pending predictions that were enqueued before it.
1207            this.update(cx, |this, cx| {
1208                let project_state = this.get_or_init_project(&project, cx);
1209
1210                let is_cancelled = project_state
1211                    .cancelled_predictions
1212                    .remove(&pending_prediction_id);
1213
1214                let new_current_prediction = if !is_cancelled
1215                    && let Some((prediction_result, requested_by)) = new_prediction_result
1216                {
1217                    match prediction_result.prediction {
1218                        Ok(prediction) => {
1219                            let new_prediction = CurrentEditPrediction {
1220                                requested_by,
1221                                prediction,
1222                                was_shown: false,
1223                            };
1224
1225                            if let Some(current_prediction) =
1226                                project_state.current_prediction.as_ref()
1227                            {
1228                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1229                                {
1230                                    this.reject_current_prediction(
1231                                        EditPredictionRejectReason::Replaced,
1232                                        &project,
1233                                    );
1234
1235                                    Some(new_prediction)
1236                                } else {
1237                                    this.reject_prediction(
1238                                        new_prediction.prediction.id,
1239                                        EditPredictionRejectReason::CurrentPreferred,
1240                                        false,
1241                                    );
1242                                    None
1243                                }
1244                            } else {
1245                                Some(new_prediction)
1246                            }
1247                        }
1248                        Err(reject_reason) => {
1249                            this.reject_prediction(prediction_result.id, reject_reason, false);
1250                            None
1251                        }
1252                    }
1253                } else {
1254                    None
1255                };
1256
1257                let project_state = this.get_or_init_project(&project, cx);
1258
1259                if let Some(new_prediction) = new_current_prediction {
1260                    project_state.current_prediction = Some(new_prediction);
1261                }
1262
1263                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1264                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1265                    if pending_prediction.id == pending_prediction_id {
1266                        pending_predictions.remove(ix);
1267                        for pending_prediction in pending_predictions.drain(0..ix) {
1268                            project_state.cancel_pending_prediction(pending_prediction, cx)
1269                        }
1270                        break;
1271                    }
1272                }
1273                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1274                cx.notify();
1275            })
1276            .ok();
1277
1278            new_prediction_id
1279        });
1280
1281        if project_state.pending_predictions.len() <= 1 {
1282            project_state.pending_predictions.push(PendingPrediction {
1283                id: pending_prediction_id,
1284                task,
1285            });
1286        } else if project_state.pending_predictions.len() == 2 {
1287            let pending_prediction = project_state.pending_predictions.pop().unwrap();
1288            project_state.pending_predictions.push(PendingPrediction {
1289                id: pending_prediction_id,
1290                task,
1291            });
1292            project_state.cancel_pending_prediction(pending_prediction, cx);
1293        }
1294    }
1295
1296    pub fn request_prediction(
1297        &mut self,
1298        project: &Entity<Project>,
1299        active_buffer: &Entity<Buffer>,
1300        position: language::Anchor,
1301        trigger: PredictEditsRequestTrigger,
1302        cx: &mut Context<Self>,
1303    ) -> Task<Result<Option<EditPredictionResult>>> {
1304        self.request_prediction_internal(
1305            project.clone(),
1306            active_buffer.clone(),
1307            position,
1308            trigger,
1309            cx.has_flag::<Zeta2FeatureFlag>(),
1310            cx,
1311        )
1312    }
1313
    /// Requests a prediction from the configured model at `position` in `active_buffer`.
    ///
    /// If the model returns no prediction, `allow_jump` is set, and the project has recorded
    /// edit events, a diagnostic outside the ±`DIAGNOSTIC_LINES_RANGE` row window around the
    /// cursor is located with `next_diagnostic_location`. A prediction is then requested there
    /// once, with `allow_jump = false` so the jump is not repeated.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Rows above and below the cursor treated as "near" when choosing a diagnostic to jump to.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_project(&project, cx);
        // NOTE: `project_state` stays borrowed (immutably) until the Sweep arm below reads
        // `recent_paths` from it.
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let events = project_state.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let related_files = if self.use_context {
            self.context_for_project(&project, cx).to_vec()
        } else {
            Vec::new()
        };

        // Dispatch to the active model; each arm yields a task resolving to the model's result.
        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                related_files,
                trigger,
                cx,
            ),
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &project_state.recent_paths,
                related_files,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                // Same value as `cursor_point` computed above, recomputed from the snapshot.
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    // Retry once at the diagnostic; `false` prevents a chain of jumps.
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1414
    /// Picks a diagnostic location to jump to for a follow-up prediction request.
    ///
    /// First looks in the active buffer. It takes the primary entry of each diagnostic group
    /// whose range does not overlap `active_buffer_diagnostic_search_range`, and chooses the one
    /// whose start row is closest to `active_buffer_cursor_point`. If there is none, it picks
    /// another file with diagnostic summaries, preferring the one that shares the most leading
    /// path components with the active buffer's path. That buffer is opened and its diagnostic
    /// with the lowest `severity` value is chosen. Returns `None` when neither search finds a
    /// location.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Fall back to another file that has diagnostics.
        if jump_location.is_none() {
            // `None` when the active buffer has no backing file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // NOTE(review): the lowest `severity` value is chosen; this is the most severe
                // diagnostic if severities follow LSP ordering (ERROR = 1) — confirm.
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1493
    /// Sends an OpenAI-format `request` to the raw prediction endpoint.
    ///
    /// The request goes to `PREDICT_EDITS_URL` when set, otherwise to the `/predict_edits/raw`
    /// LLM endpoint. Returns the parsed response plus any usage information reported by the
    /// server. With the `eval-support` feature and an `eval_cache`, responses are cached under
    /// `(eval_cache_kind, hash of url + pretty-printed request)`. A cache hit is returned without
    /// a network request and with `None` usage.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        // Look up (or prepare to store) a cached response keyed by the URL and request body.
        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Populate the cache only after a successful request.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1551
1552    fn handle_api_response<T>(
1553        this: &WeakEntity<Self>,
1554        response: Result<(T, Option<EditPredictionUsage>)>,
1555        cx: &mut gpui::AsyncApp,
1556    ) -> Result<T> {
1557        match response {
1558            Ok((data, usage)) => {
1559                if let Some(usage) = usage {
1560                    this.update(cx, |this, cx| {
1561                        this.user_store.update(cx, |user_store, cx| {
1562                            user_store.update_edit_prediction_usage(usage, cx);
1563                        });
1564                    })
1565                    .ok();
1566                }
1567                Ok(data)
1568            }
1569            Err(err) => {
1570                if err.is::<ZedUpdateRequiredError>() {
1571                    cx.update(|cx| {
1572                        this.update(cx, |this, _cx| {
1573                            this.update_required = true;
1574                        })
1575                        .ok();
1576
1577                        let error_message: SharedString = err.to_string().into();
1578                        show_app_notification(
1579                            NotificationId::unique::<ZedUpdateRequiredError>(),
1580                            cx,
1581                            move |cx| {
1582                                cx.new(|cx| {
1583                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1584                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1585                                })
1586                            },
1587                        );
1588                    })
1589                    .ok();
1590                }
1591                Err(err)
1592            }
1593        }
1594    }
1595
1596    async fn send_api_request<Res>(
1597        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1598        client: Arc<Client>,
1599        llm_token: LlmApiToken,
1600        app_version: Version,
1601    ) -> Result<(Res, Option<EditPredictionUsage>)>
1602    where
1603        Res: DeserializeOwned,
1604    {
1605        let http_client = client.http_client();
1606        let mut token = llm_token.acquire(&client).await?;
1607        let mut did_retry = false;
1608
1609        loop {
1610            let request_builder = http_client::Request::builder().method(Method::POST);
1611
1612            let request = build(
1613                request_builder
1614                    .header("Content-Type", "application/json")
1615                    .header("Authorization", format!("Bearer {}", token))
1616                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
1617            )?;
1618
1619            let mut response = http_client.send(request).await?;
1620
1621            if let Some(minimum_required_version) = response
1622                .headers()
1623                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1624                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1625            {
1626                anyhow::ensure!(
1627                    app_version >= minimum_required_version,
1628                    ZedUpdateRequiredError {
1629                        minimum_version: minimum_required_version
1630                    }
1631                );
1632            }
1633
1634            if response.status().is_success() {
1635                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1636
1637                let mut body = Vec::new();
1638                response.body_mut().read_to_end(&mut body).await?;
1639                return Ok((serde_json::from_slice(&body)?, usage));
1640            } else if !did_retry
1641                && response
1642                    .headers()
1643                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1644                    .is_some()
1645            {
1646                did_retry = true;
1647                token = llm_token.refresh(&client).await?;
1648            } else {
1649                let mut body = String::new();
1650                response.body_mut().read_to_string(&mut body).await?;
1651                anyhow::bail!(
1652                    "Request failed with status: {:?}\nBody: {}",
1653                    response.status(),
1654                    body
1655                );
1656            }
1657        }
1658    }
1659
1660    pub fn refresh_context(
1661        &mut self,
1662        project: &Entity<Project>,
1663        buffer: &Entity<language::Buffer>,
1664        cursor_position: language::Anchor,
1665        cx: &mut Context<Self>,
1666    ) {
1667        if self.use_context {
1668            self.get_or_init_project(project, cx)
1669                .context
1670                .update(cx, |store, cx| {
1671                    store.refresh(buffer.clone(), cursor_position, cx);
1672                });
1673        }
1674    }
1675
1676    fn is_file_open_source(
1677        &self,
1678        project: &Entity<Project>,
1679        file: &Arc<dyn File>,
1680        cx: &App,
1681    ) -> bool {
1682        if !file.is_local() || file.is_private() {
1683            return false;
1684        }
1685        let Some(project_state) = self.projects.get(&project.entity_id()) else {
1686            return false;
1687        };
1688        project_state
1689            .license_detection_watchers
1690            .get(&file.worktree_id(cx))
1691            .as_ref()
1692            .is_some_and(|watcher| watcher.is_project_open_source())
1693    }
1694
1695    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1696        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1697    }
1698
1699    fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
1700        if !self.data_collection_choice.is_enabled() {
1701            return false;
1702        }
1703        events.iter().all(|event| {
1704            matches!(
1705                event.as_ref(),
1706                Event::BufferChange {
1707                    in_open_source_repo: true,
1708                    ..
1709                }
1710            )
1711        })
1712    }
1713
1714    fn load_data_collection_choice() -> DataCollectionChoice {
1715        let choice = KEY_VALUE_STORE
1716            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1717            .log_err()
1718            .flatten();
1719
1720        match choice.as_deref() {
1721            Some("true") => DataCollectionChoice::Enabled,
1722            Some("false") => DataCollectionChoice::Disabled,
1723            Some(_) => {
1724                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1725                DataCollectionChoice::NotAnswered
1726            }
1727            None => DataCollectionChoice::NotAnswered,
1728        }
1729    }
1730
1731    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1732        self.data_collection_choice = self.data_collection_choice.toggle();
1733        let new_choice = self.data_collection_choice;
1734        db::write_and_log(cx, move || {
1735            KEY_VALUE_STORE.write_kvp(
1736                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1737                new_choice.is_enabled().to_string(),
1738            )
1739        });
1740    }
1741
1742    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1743        self.shown_predictions.iter()
1744    }
1745
1746    pub fn shown_completions_len(&self) -> usize {
1747        self.shown_predictions.len()
1748    }
1749
1750    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
1751        self.rated_predictions.contains(id)
1752    }
1753
    /// Records the user's rating of `prediction` and reports it via telemetry.
    ///
    /// Marks the prediction id as rated and emits an "Edit Prediction Rated"
    /// event. The event carries the rating, the prediction's inputs, the predicted
    /// edits rendered by `as_unified_diff`, and the free-form `feedback`. It then
    /// starts a telemetry flush (detached, not awaited) and notifies observers.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Send the rating promptly rather than waiting for the next batch flush.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1772
1773    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
1774        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
1775            && all_language_settings(None, cx).edit_predictions.use_context;
1776    }
1777}
1778
/// Error returned when the server's minimum-required-version response header
/// names a Zed version newer than the running one.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest Zed version the server reported as acceptable.
    minimum_version: Version,
}
1786
/// Key for an [`EvalCache`] entry: the kind of cached data plus a `u64`
/// (NOTE(review): presumably a hash of the cached input — confirm at call sites).
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
1789
/// Category of data stored in an [`EvalCache`]; it forms the first half of an
/// [`EvalCacheKey`].
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so that cache keys can be
/// stored directly in hash-based collections.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
1797
/// Renders the entry kind as a lowercase name: `context`, `search`, or `prediction`.
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(name)
    }
}
1808
/// Storage backend for cached eval data, keyed by [`EvalCacheKey`].
/// Implementors must be `Send + Sync`.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the value cached under `key`, or `None` if there is none.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` under `key`. `input` is passed alongside
    /// (NOTE(review): presumably the data the value was computed from — confirm
    /// with implementors).
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
1814
/// The user's answer to whether edit-prediction data may be collected.
///
/// Persisted in the key-value store as `"true"`/`"false"`; a missing or
/// unrecognized stored value is treated as [`DataCollectionChoice::NotAnswered`].
/// `PartialEq`/`Eq` let callers compare choices directly instead of matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    /// The user has not made a choice yet.
    NotAnswered,
    /// The user opted in to data collection.
    Enabled,
    /// The user opted out of data collection.
    Disabled,
}
1821
1822impl DataCollectionChoice {
1823    pub fn is_enabled(self) -> bool {
1824        match self {
1825            Self::Enabled => true,
1826            Self::NotAnswered | Self::Disabled => false,
1827        }
1828    }
1829
1830    pub fn is_answered(self) -> bool {
1831        match self {
1832            Self::Enabled | Self::Disabled => true,
1833            Self::NotAnswered => false,
1834        }
1835    }
1836
1837    #[must_use]
1838    pub fn toggle(&self) -> DataCollectionChoice {
1839        match self {
1840            Self::Enabled => Self::Disabled,
1841            Self::Disabled => Self::Enabled,
1842            Self::NotAnswered => Self::Enabled,
1843        }
1844    }
1845}
1846
1847impl From<bool> for DataCollectionChoice {
1848    fn from(value: bool) -> Self {
1849        match value {
1850            true => DataCollectionChoice::Enabled,
1851            false => DataCollectionChoice::Disabled,
1852        }
1853    }
1854}
1855
1856struct ZedPredictUpsell;
1857
1858impl Dismissable for ZedPredictUpsell {
1859    const KEY: &'static str = "dismissed-edit-predict-upsell";
1860
1861    fn dismissed() -> bool {
1862        // To make this backwards compatible with older versions of Zed, we
1863        // check if the user has seen the previous Edit Prediction Onboarding
1864        // before, by checking the data collection choice which was written to
1865        // the database once the user clicked on "Accept and Enable"
1866        if KEY_VALUE_STORE
1867            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1868            .log_err()
1869            .is_some_and(|s| s.is_some())
1870        {
1871            return true;
1872        }
1873
1874        KEY_VALUE_STORE
1875            .read_kvp(Self::KEY)
1876            .log_err()
1877            .is_some_and(|s| s.is_some())
1878    }
1879}
1880
1881pub fn should_show_upsell_modal() -> bool {
1882    !ZedPredictUpsell::dismissed()
1883}
1884
/// Registers edit-prediction onboarding actions on every newly created `Workspace`.
///
/// - `zed_actions::OpenZedPredictOnboarding` toggles the `ZedPredictModal` using
///   the workspace's user store and client.
/// - `ResetOnboarding` updates the settings file so that the global
///   `edit_prediction_provider` becomes `EditPredictionProvider::None`.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            // Persisted through the settings file, so the reset survives restarts.
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    // Create the `features` section if the settings don't have one yet.
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}