edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{
   5    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   6};
   7use cloud_llm_client::{
   8    EditPredictionRejectReason, EditPredictionRejection,
   9    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  10    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  11};
  12use collections::{HashMap, HashSet};
  13use copilot::{Copilot, Reinstall, SignIn, SignOut};
  14use db::kvp::{Dismissable, KEY_VALUE_STORE};
  15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::{
  18    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  19    channel::mpsc::{self, UnboundedReceiver},
  20    select_biased,
  21};
  22use gpui::BackgroundExecutor;
  23use gpui::http_client::Url;
  24use gpui::{
  25    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  26    http_client::{self, AsyncBody, Method},
  27    prelude::*,
  28};
  29use language::language_settings::all_language_settings;
  30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  31use language::{BufferSnapshot, OffsetRangeExt};
  32use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  33use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  34use release_channel::AppVersion;
  35use semver::Version;
  36use serde::de::DeserializeOwned;
  37use settings::{EditPredictionProvider, Settings as _, update_settings_file};
  38use std::collections::{VecDeque, hash_map};
  39use std::env;
  40use text::Edit;
  41use workspace::Workspace;
  42use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  43
  44use std::mem;
  45use std::ops::Range;
  46use std::path::Path;
  47use std::rc::Rc;
  48use std::str::FromStr as _;
  49use std::sync::Arc;
  50use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  51use thiserror::Error;
  52use util::{RangeExt as _, ResultExt as _};
  53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  54
  55pub mod cursor_excerpt;
  56pub mod example_spec;
  57mod license_detection;
  58pub mod mercury;
  59pub mod ollama;
  60mod onboarding_modal;
  61pub mod open_ai_response;
  62mod prediction;
  63pub mod sweep_ai;
  64
  65pub mod udiff;
  66
  67mod capture_example;
  68mod zed_edit_prediction_delegate;
  69pub mod zeta1;
  70pub mod zeta2;
  71
  72#[cfg(test)]
  73mod edit_prediction_tests;
  74
  75use crate::capture_example::should_send_testing_zeta2_request;
  76use crate::license_detection::LicenseDetectionWatcher;
  77use crate::mercury::Mercury;
  78use crate::ollama::Ollama;
  79use crate::onboarding_modal::ZedPredictModal;
  80pub use crate::prediction::EditPrediction;
  81pub use crate::prediction::EditPredictionId;
  82use crate::prediction::EditPredictionResult;
  83pub use crate::sweep_ai::SweepAi;
  84pub use capture_example::capture_example;
  85pub use language_model::ApiKeyState;
  86pub use telemetry_events::EditPredictionRating;
  87pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  88
// Actions registered under the `edit_prediction` namespace.
// The `///` lines below are input to the `actions!` macro itself, so their text is
// left exactly as-is.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  98
  99/// Maximum number of events to track.
 100const EVENT_COUNT_MAX: usize = 6;
 101const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 102const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
 103const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 104const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 105
 106pub struct Zeta2FeatureFlag;
 107
 108impl FeatureFlag for Zeta2FeatureFlag {
 109    const NAME: &'static str = "zeta2";
 110
 111    fn enabled_for_staff() -> bool {
 112        true
 113    }
 114}
 115
 116#[derive(Clone)]
 117struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
 118
 119impl Global for EditPredictionStoreGlobal {}
 120
 121/// Configuration for using the raw Zeta2 endpoint.
 122/// When set, the client uses the raw endpoint and constructs the prompt itself.
 123/// The version is also used as the Baseten environment name (lowercased).
 124#[derive(Clone)]
 125pub struct Zeta2RawConfig {
 126    pub model_id: Option<String>,
 127    pub format: ZetaFormat,
 128}
 129
 130pub struct EditPredictionStore {
 131    client: Arc<Client>,
 132    user_store: Entity<UserStore>,
 133    llm_token: LlmApiToken,
 134    _llm_token_subscription: Subscription,
 135    projects: HashMap<EntityId, ProjectState>,
 136    update_required: bool,
 137    edit_prediction_model: EditPredictionModel,
 138    zeta2_raw_config: Option<Zeta2RawConfig>,
 139    pub sweep_ai: SweepAi,
 140    pub mercury: Mercury,
 141    pub ollama: Ollama,
 142    data_collection_choice: DataCollectionChoice,
 143    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
 144    shown_predictions: VecDeque<EditPrediction>,
 145    rated_predictions: HashSet<EditPredictionId>,
 146}
 147
 148#[derive(Copy, Clone, Default, PartialEq, Eq)]
 149pub enum EditPredictionModel {
 150    #[default]
 151    Zeta1,
 152    Zeta2,
 153    Sweep,
 154    Mercury,
 155    Ollama,
 156}
 157
 158#[derive(Clone)]
 159pub struct EditPredictionModelInput {
 160    project: Entity<Project>,
 161    buffer: Entity<Buffer>,
 162    snapshot: BufferSnapshot,
 163    position: Anchor,
 164    events: Vec<Arc<zeta_prompt::Event>>,
 165    related_files: Vec<RelatedFile>,
 166    recent_paths: VecDeque<ProjectPath>,
 167    trigger: PredictEditsRequestTrigger,
 168    diagnostic_search_range: Range<Point>,
 169    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 170    pub user_actions: Vec<UserActionRecord>,
 171}
 172
 173#[derive(Debug)]
 174pub enum DebugEvent {
 175    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
 176    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
 177    EditPredictionStarted(EditPredictionStartedDebugEvent),
 178    EditPredictionFinished(EditPredictionFinishedDebugEvent),
 179}
 180
 181#[derive(Debug)]
 182pub struct ContextRetrievalStartedDebugEvent {
 183    pub project_entity_id: EntityId,
 184    pub timestamp: Instant,
 185    pub search_prompt: String,
 186}
 187
 188#[derive(Debug)]
 189pub struct ContextRetrievalFinishedDebugEvent {
 190    pub project_entity_id: EntityId,
 191    pub timestamp: Instant,
 192    pub metadata: Vec<(&'static str, SharedString)>,
 193}
 194
 195#[derive(Debug)]
 196pub struct EditPredictionStartedDebugEvent {
 197    pub buffer: WeakEntity<Buffer>,
 198    pub position: Anchor,
 199    pub prompt: Option<String>,
 200}
 201
 202#[derive(Debug)]
 203pub struct EditPredictionFinishedDebugEvent {
 204    pub buffer: WeakEntity<Buffer>,
 205    pub position: Anchor,
 206    pub model_output: Option<String>,
 207}
 208
 209const USER_ACTION_HISTORY_SIZE: usize = 16;
 210
 211#[derive(Clone, Debug)]
 212pub struct UserActionRecord {
 213    pub action_type: UserActionType,
 214    pub buffer_id: EntityId,
 215    pub line_number: u32,
 216    pub offset: usize,
 217    pub timestamp_epoch_ms: u64,
 218}
 219
 220#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 221pub enum UserActionType {
 222    InsertChar,
 223    InsertSelection,
 224    DeleteChar,
 225    DeleteSelection,
 226    CursorMovement,
 227}
 228
 229/// An event with associated metadata for reconstructing buffer state.
 230#[derive(Clone)]
 231pub struct StoredEvent {
 232    pub event: Arc<zeta_prompt::Event>,
 233    pub old_snapshot: TextBufferSnapshot,
 234}
 235
 236struct ProjectState {
 237    events: VecDeque<StoredEvent>,
 238    last_event: Option<LastEvent>,
 239    recent_paths: VecDeque<ProjectPath>,
 240    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 241    current_prediction: Option<CurrentEditPrediction>,
 242    next_pending_prediction_id: usize,
 243    pending_predictions: ArrayVec<PendingPrediction, 2>,
 244    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 245    last_prediction_refresh: Option<(EntityId, Instant)>,
 246    cancelled_predictions: HashSet<usize>,
 247    context: Entity<RelatedExcerptStore>,
 248    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 249    user_actions: VecDeque<UserActionRecord>,
 250    _subscriptions: [gpui::Subscription; 2],
 251    copilot: Option<Entity<Copilot>>,
 252}
 253
 254impl ProjectState {
 255    fn record_user_action(&mut self, action: UserActionRecord) {
 256        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 257            self.user_actions.pop_front();
 258        }
 259        self.user_actions.push_back(action);
 260    }
 261
 262    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 263        self.events
 264            .iter()
 265            .cloned()
 266            .chain(
 267                self.last_event
 268                    .as_ref()
 269                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 270            )
 271            .collect()
 272    }
 273
 274    pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
 275        self.events
 276            .iter()
 277            .cloned()
 278            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 279                let (one, two) = event.split_by_pause();
 280                let one = one.finalize(&self.license_detection_watchers, cx);
 281                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 282                one.into_iter().chain(two)
 283            }))
 284            .collect()
 285    }
 286
 287    fn cancel_pending_prediction(
 288        &mut self,
 289        pending_prediction: PendingPrediction,
 290        cx: &mut Context<EditPredictionStore>,
 291    ) {
 292        self.cancelled_predictions.insert(pending_prediction.id);
 293
 294        if pending_prediction.drop_on_cancel {
 295            drop(pending_prediction.task);
 296        } else {
 297            cx.spawn(async move |this, cx| {
 298                let Some(prediction_id) = pending_prediction.task.await else {
 299                    return;
 300                };
 301
 302                this.update(cx, |this, cx| {
 303                    this.reject_prediction(
 304                        prediction_id,
 305                        EditPredictionRejectReason::Canceled,
 306                        false,
 307                        cx,
 308                    );
 309                })
 310                .ok();
 311            })
 312            .detach()
 313        }
 314    }
 315
 316    fn active_buffer(
 317        &self,
 318        project: &Entity<Project>,
 319        cx: &App,
 320    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 321        let project = project.read(cx);
 322        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 323        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 324        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 325        Some((active_buffer, registered_buffer.last_position))
 326    }
 327}
 328
 329#[derive(Debug, Clone)]
 330struct CurrentEditPrediction {
 331    pub requested_by: PredictionRequestedBy,
 332    pub prediction: EditPrediction,
 333    pub was_shown: bool,
 334    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
 335}
 336
 337impl CurrentEditPrediction {
 338    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 339        let Some(new_edits) = self
 340            .prediction
 341            .interpolate(&self.prediction.buffer.read(cx))
 342        else {
 343            return false;
 344        };
 345
 346        if self.prediction.buffer != old_prediction.prediction.buffer {
 347            return true;
 348        }
 349
 350        let Some(old_edits) = old_prediction
 351            .prediction
 352            .interpolate(&old_prediction.prediction.buffer.read(cx))
 353        else {
 354            return true;
 355        };
 356
 357        let requested_by_buffer_id = self.requested_by.buffer_id();
 358
 359        // This reduces the occurrence of UI thrash from replacing edits
 360        //
 361        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 362        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 363            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 364            && old_edits.len() == 1
 365            && new_edits.len() == 1
 366        {
 367            let (old_range, old_text) = &old_edits[0];
 368            let (new_range, new_text) = &new_edits[0];
 369            new_range == old_range && new_text.starts_with(old_text.as_ref())
 370        } else {
 371            true
 372        }
 373    }
 374}
 375
 376#[derive(Debug, Clone)]
 377enum PredictionRequestedBy {
 378    DiagnosticsUpdate,
 379    Buffer(EntityId),
 380}
 381
 382impl PredictionRequestedBy {
 383    pub fn buffer_id(&self) -> Option<EntityId> {
 384        match self {
 385            PredictionRequestedBy::DiagnosticsUpdate => None,
 386            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 387        }
 388    }
 389}
 390
 391#[derive(Debug)]
 392struct PendingPrediction {
 393    id: usize,
 394    task: Task<Option<EditPredictionId>>,
 395    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
 396    /// If false, the task is awaited to completion so rejection can be reported.
 397    drop_on_cancel: bool,
 398}
 399
 400/// A prediction from the perspective of a buffer.
 401#[derive(Debug)]
 402enum BufferEditPrediction<'a> {
 403    Local { prediction: &'a EditPrediction },
 404    Jump { prediction: &'a EditPrediction },
 405}
 406
 407#[cfg(test)]
 408impl std::ops::Deref for BufferEditPrediction<'_> {
 409    type Target = EditPrediction;
 410
 411    fn deref(&self) -> &Self::Target {
 412        match self {
 413            BufferEditPrediction::Local { prediction } => prediction,
 414            BufferEditPrediction::Jump { prediction } => prediction,
 415        }
 416    }
 417}
 418
 419struct RegisteredBuffer {
 420    file: Option<Arc<dyn File>>,
 421    snapshot: TextBufferSnapshot,
 422    last_position: Option<Anchor>,
 423    _subscriptions: [gpui::Subscription; 2],
 424}
 425
 426#[derive(Clone)]
 427struct LastEvent {
 428    old_snapshot: TextBufferSnapshot,
 429    new_snapshot: TextBufferSnapshot,
 430    old_file: Option<Arc<dyn File>>,
 431    new_file: Option<Arc<dyn File>>,
 432    edit_range: Option<Range<Anchor>>,
 433    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
 434    last_edit_time: Option<Instant>,
 435}
 436
 437impl LastEvent {
 438    pub fn finalize(
 439        &self,
 440        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 441        cx: &App,
 442    ) -> Option<StoredEvent> {
 443        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 444        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 445
 446        let in_open_source_repo =
 447            [self.new_file.as_ref(), self.old_file.as_ref()]
 448                .iter()
 449                .all(|file| {
 450                    file.is_some_and(|file| {
 451                        license_detection_watchers
 452                            .get(&file.worktree_id(cx))
 453                            .is_some_and(|watcher| watcher.is_project_open_source())
 454                    })
 455                });
 456
 457        let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 458
 459        if path == old_path && diff.is_empty() {
 460            None
 461        } else {
 462            Some(StoredEvent {
 463                event: Arc::new(zeta_prompt::Event::BufferChange {
 464                    old_path,
 465                    path,
 466                    diff,
 467                    in_open_source_repo,
 468                    // TODO: Actually detect if this edit was predicted or not
 469                    predicted: false,
 470                }),
 471                old_snapshot: self.old_snapshot.clone(),
 472            })
 473        }
 474    }
 475
 476    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 477        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 478            return (self.clone(), None);
 479        };
 480
 481        let before = LastEvent {
 482            old_snapshot: self.old_snapshot.clone(),
 483            new_snapshot: boundary_snapshot.clone(),
 484            old_file: self.old_file.clone(),
 485            new_file: self.new_file.clone(),
 486            edit_range: None,
 487            snapshot_after_last_editing_pause: None,
 488            last_edit_time: self.last_edit_time,
 489        };
 490
 491        let after = LastEvent {
 492            old_snapshot: boundary_snapshot.clone(),
 493            new_snapshot: self.new_snapshot.clone(),
 494            old_file: self.old_file.clone(),
 495            new_file: self.new_file.clone(),
 496            edit_range: None,
 497            snapshot_after_last_editing_pause: None,
 498            last_edit_time: self.last_edit_time,
 499        };
 500
 501        (before, Some(after))
 502    }
 503}
 504
 505pub(crate) fn compute_diff_between_snapshots(
 506    old_snapshot: &TextBufferSnapshot,
 507    new_snapshot: &TextBufferSnapshot,
 508) -> Option<String> {
 509    let edits: Vec<Edit<usize>> = new_snapshot
 510        .edits_since::<usize>(&old_snapshot.version)
 511        .collect();
 512
 513    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 514
 515    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 516    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 517    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 518    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 519
 520    const CONTEXT_LINES: u32 = 3;
 521
 522    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 523    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 524    let old_context_end_row =
 525        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 526    let new_context_end_row =
 527        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 528
 529    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 530    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 531    let old_end_line_offset = old_snapshot
 532        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 533    let new_end_line_offset = new_snapshot
 534        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 535    let old_edit_range = old_start_line_offset..old_end_line_offset;
 536    let new_edit_range = new_start_line_offset..new_end_line_offset;
 537
 538    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 539    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 540
 541    let diff = language::unified_diff_with_offsets(
 542        &old_region_text,
 543        &new_region_text,
 544        old_context_start_row,
 545        new_context_start_row,
 546    );
 547
 548    Some(diff)
 549}
 550
 551fn buffer_path_with_id_fallback(
 552    file: Option<&Arc<dyn File>>,
 553    snapshot: &TextBufferSnapshot,
 554    cx: &App,
 555) -> Arc<Path> {
 556    if let Some(file) = file {
 557        file.full_path(cx).into()
 558    } else {
 559        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 560    }
 561}
 562
 563impl EditPredictionStore {
 564    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 565        cx.try_global::<EditPredictionStoreGlobal>()
 566            .map(|global| global.0.clone())
 567    }
 568
 569    pub fn global(
 570        client: &Arc<Client>,
 571        user_store: &Entity<UserStore>,
 572        cx: &mut App,
 573    ) -> Entity<Self> {
 574        cx.try_global::<EditPredictionStoreGlobal>()
 575            .map(|global| global.0.clone())
 576            .unwrap_or_else(|| {
 577                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 578                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 579                ep_store
 580            })
 581    }
 582
 583    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 584        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 585        let data_collection_choice = Self::load_data_collection_choice();
 586
 587        let llm_token = LlmApiToken::default();
 588
 589        let (reject_tx, reject_rx) = mpsc::unbounded();
 590        cx.background_spawn({
 591            let client = client.clone();
 592            let llm_token = llm_token.clone();
 593            let app_version = AppVersion::global(cx);
 594            let background_executor = cx.background_executor().clone();
 595            async move {
 596                Self::handle_rejected_predictions(
 597                    reject_rx,
 598                    client,
 599                    llm_token,
 600                    app_version,
 601                    background_executor,
 602                )
 603                .await
 604            }
 605        })
 606        .detach();
 607
 608        let this = Self {
 609            projects: HashMap::default(),
 610            client,
 611            user_store,
 612            llm_token,
 613            _llm_token_subscription: cx.subscribe(
 614                &refresh_llm_token_listener,
 615                |this, _listener, _event, cx| {
 616                    let client = this.client.clone();
 617                    let llm_token = this.llm_token.clone();
 618                    cx.spawn(async move |_this, _cx| {
 619                        llm_token.refresh(&client).await?;
 620                        anyhow::Ok(())
 621                    })
 622                    .detach_and_log_err(cx);
 623                },
 624            ),
 625            update_required: false,
 626            edit_prediction_model: EditPredictionModel::Zeta2,
 627            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 628            sweep_ai: SweepAi::new(cx),
 629            mercury: Mercury::new(cx),
 630            ollama: Ollama::new(),
 631
 632            data_collection_choice,
 633            reject_predictions_tx: reject_tx,
 634            rated_predictions: Default::default(),
 635            shown_predictions: Default::default(),
 636        };
 637
 638        this
 639    }
 640
 641    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 642        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 643        let format = ZetaFormat::parse(&version_str).ok()?;
 644        let model_id = env::var("ZED_ZETA_MODEL").ok();
 645        Some(Zeta2RawConfig { model_id, format })
 646    }
 647
 648    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 649        self.edit_prediction_model = model;
 650    }
 651
 652    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
 653        self.zeta2_raw_config = Some(config);
 654    }
 655
 656    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
 657        self.zeta2_raw_config.as_ref()
 658    }
 659
 660    pub fn icons(&self) -> edit_prediction_types::EditPredictionIconSet {
 661        use ui::IconName;
 662        match self.edit_prediction_model {
 663            EditPredictionModel::Sweep => {
 664                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 665                    .with_disabled(IconName::SweepAiDisabled)
 666                    .with_up(IconName::SweepAiUp)
 667                    .with_down(IconName::SweepAiDown)
 668                    .with_error(IconName::SweepAiError)
 669            }
 670            EditPredictionModel::Mercury => {
 671                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 672            }
 673            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
 674                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 675                    .with_disabled(IconName::ZedPredictDisabled)
 676                    .with_up(IconName::ZedPredictUp)
 677                    .with_down(IconName::ZedPredictDown)
 678                    .with_error(IconName::ZedPredictError)
 679            }
 680            EditPredictionModel::Ollama => {
 681                edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 682            }
 683        }
 684    }
 685
 686    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 687        self.sweep_ai.api_token.read(cx).has_key()
 688    }
 689
 690    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 691        self.mercury.api_token.read(cx).has_key()
 692    }
 693
 694    pub fn clear_history(&mut self) {
 695        for project_state in self.projects.values_mut() {
 696            project_state.events.clear();
 697            project_state.last_event.take();
 698        }
 699    }
 700
 701    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 702        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 703            project_state.events.clear();
 704            project_state.last_event.take();
 705        }
 706    }
 707
 708    pub fn edit_history_for_project(
 709        &self,
 710        project: &Entity<Project>,
 711        cx: &App,
 712    ) -> Vec<StoredEvent> {
 713        self.projects
 714            .get(&project.entity_id())
 715            .map(|project_state| project_state.events(cx))
 716            .unwrap_or_default()
 717    }
 718
 719    pub fn edit_history_for_project_with_pause_split_last_event(
 720        &self,
 721        project: &Entity<Project>,
 722        cx: &App,
 723    ) -> Vec<StoredEvent> {
 724        self.projects
 725            .get(&project.entity_id())
 726            .map(|project_state| project_state.events_split_by_pause(cx))
 727            .unwrap_or_default()
 728    }
 729
 730    pub fn context_for_project<'a>(
 731        &'a self,
 732        project: &Entity<Project>,
 733        cx: &'a mut App,
 734    ) -> Vec<RelatedFile> {
 735        self.projects
 736            .get(&project.entity_id())
 737            .map(|project| {
 738                project
 739                    .context
 740                    .update(cx, |context, cx| context.related_files(cx))
 741            })
 742            .unwrap_or_default()
 743    }
 744
 745    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 746        self.projects
 747            .get(&project.entity_id())
 748            .and_then(|project| project.copilot.clone())
 749    }
 750
 751    pub fn start_copilot_for_project(
 752        &mut self,
 753        project: &Entity<Project>,
 754        cx: &mut Context<Self>,
 755    ) -> Option<Entity<Copilot>> {
 756        if DisableAiSettings::get(None, cx).disable_ai {
 757            return None;
 758        }
 759        let state = self.get_or_init_project(project, cx);
 760
 761        if state.copilot.is_some() {
 762            return state.copilot.clone();
 763        }
 764        let _project = project.clone();
 765        let project = project.read(cx);
 766
 767        let node = project.node_runtime().cloned();
 768        if let Some(node) = node {
 769            let next_id = project.languages().next_language_server_id();
 770            let fs = project.fs().clone();
 771
 772            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 773            state.copilot = Some(copilot.clone());
 774            Some(copilot)
 775        } else {
 776            None
 777        }
 778    }
 779
 780    pub fn context_for_project_with_buffers<'a>(
 781        &'a self,
 782        project: &Entity<Project>,
 783        cx: &'a mut App,
 784    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 785        self.projects
 786            .get(&project.entity_id())
 787            .map(|project| {
 788                project
 789                    .context
 790                    .update(cx, |context, cx| context.related_files_with_buffers(cx))
 791            })
 792            .unwrap_or_default()
 793    }
 794
 795    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 796        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta2) {
 797            self.user_store.read(cx).edit_prediction_usage()
 798        } else {
 799            None
 800        }
 801    }
 802
 803    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 804        self.get_or_init_project(project, cx);
 805    }
 806
 807    pub fn register_buffer(
 808        &mut self,
 809        buffer: &Entity<Buffer>,
 810        project: &Entity<Project>,
 811        cx: &mut Context<Self>,
 812    ) {
 813        let project_state = self.get_or_init_project(project, cx);
 814        Self::register_buffer_impl(project_state, buffer, project, cx);
 815    }
 816
 817    fn get_or_init_project(
 818        &mut self,
 819        project: &Entity<Project>,
 820        cx: &mut Context<Self>,
 821    ) -> &mut ProjectState {
 822        let entity_id = project.entity_id();
 823        self.projects
 824            .entry(entity_id)
 825            .or_insert_with(|| ProjectState {
 826                context: {
 827                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 828                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 829                        this.handle_excerpt_store_event(entity_id, event);
 830                    })
 831                    .detach();
 832                    related_excerpt_store
 833                },
 834                events: VecDeque::new(),
 835                last_event: None,
 836                recent_paths: VecDeque::new(),
 837                debug_tx: None,
 838                registered_buffers: HashMap::default(),
 839                current_prediction: None,
 840                cancelled_predictions: HashSet::default(),
 841                pending_predictions: ArrayVec::new(),
 842                next_pending_prediction_id: 0,
 843                last_prediction_refresh: None,
 844                license_detection_watchers: HashMap::default(),
 845                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
 846                _subscriptions: [
 847                    cx.subscribe(&project, Self::handle_project_event),
 848                    cx.observe_release(&project, move |this, _, cx| {
 849                        this.projects.remove(&entity_id);
 850                        cx.notify();
 851                    }),
 852                ],
 853                copilot: None,
 854            })
 855    }
 856
 857    pub fn remove_project(&mut self, project: &Entity<Project>) {
 858        self.projects.remove(&project.entity_id());
 859    }
 860
 861    fn handle_excerpt_store_event(
 862        &mut self,
 863        project_entity_id: EntityId,
 864        event: &RelatedExcerptStoreEvent,
 865    ) {
 866        if let Some(project_state) = self.projects.get(&project_entity_id) {
 867            if let Some(debug_tx) = project_state.debug_tx.clone() {
 868                match event {
 869                    RelatedExcerptStoreEvent::StartedRefresh => {
 870                        debug_tx
 871                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 872                                ContextRetrievalStartedDebugEvent {
 873                                    project_entity_id: project_entity_id,
 874                                    timestamp: Instant::now(),
 875                                    search_prompt: String::new(),
 876                                },
 877                            ))
 878                            .ok();
 879                    }
 880                    RelatedExcerptStoreEvent::FinishedRefresh {
 881                        cache_hit_count,
 882                        cache_miss_count,
 883                        mean_definition_latency,
 884                        max_definition_latency,
 885                    } => {
 886                        debug_tx
 887                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 888                                ContextRetrievalFinishedDebugEvent {
 889                                    project_entity_id: project_entity_id,
 890                                    timestamp: Instant::now(),
 891                                    metadata: vec![
 892                                        (
 893                                            "Cache Hits",
 894                                            format!(
 895                                                "{}/{}",
 896                                                cache_hit_count,
 897                                                cache_hit_count + cache_miss_count
 898                                            )
 899                                            .into(),
 900                                        ),
 901                                        (
 902                                            "Max LSP Time",
 903                                            format!("{} ms", max_definition_latency.as_millis())
 904                                                .into(),
 905                                        ),
 906                                        (
 907                                            "Mean LSP Time",
 908                                            format!("{} ms", mean_definition_latency.as_millis())
 909                                                .into(),
 910                                        ),
 911                                    ],
 912                                },
 913                            ))
 914                            .ok();
 915                    }
 916                }
 917            }
 918        }
 919    }
 920
 921    pub fn debug_info(
 922        &mut self,
 923        project: &Entity<Project>,
 924        cx: &mut Context<Self>,
 925    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 926        let project_state = self.get_or_init_project(project, cx);
 927        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 928        project_state.debug_tx = Some(debug_watch_tx);
 929        debug_watch_rx
 930    }
 931
 932    fn handle_project_event(
 933        &mut self,
 934        project: Entity<Project>,
 935        event: &project::Event,
 936        cx: &mut Context<Self>,
 937    ) {
 938        // TODO [zeta2] init with recent paths
 939        match event {
 940            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 941                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 942                    return;
 943                };
 944                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 945                if let Some(path) = path {
 946                    if let Some(ix) = project_state
 947                        .recent_paths
 948                        .iter()
 949                        .position(|probe| probe == &path)
 950                    {
 951                        project_state.recent_paths.remove(ix);
 952                    }
 953                    project_state.recent_paths.push_front(path);
 954                }
 955            }
 956            project::Event::DiagnosticsUpdated { .. } => {
 957                if cx.has_flag::<Zeta2FeatureFlag>() {
 958                    self.refresh_prediction_from_diagnostics(project, cx);
 959                }
 960            }
 961            _ => (),
 962        }
 963    }
 964
 965    fn register_buffer_impl<'a>(
 966        project_state: &'a mut ProjectState,
 967        buffer: &Entity<Buffer>,
 968        project: &Entity<Project>,
 969        cx: &mut Context<Self>,
 970    ) -> &'a mut RegisteredBuffer {
 971        let buffer_id = buffer.entity_id();
 972
 973        if let Some(file) = buffer.read(cx).file() {
 974            let worktree_id = file.worktree_id(cx);
 975            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
 976                project_state
 977                    .license_detection_watchers
 978                    .entry(worktree_id)
 979                    .or_insert_with(|| {
 980                        let project_entity_id = project.entity_id();
 981                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
 982                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
 983                            else {
 984                                return;
 985                            };
 986                            project_state
 987                                .license_detection_watchers
 988                                .remove(&worktree_id);
 989                        })
 990                        .detach();
 991                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
 992                    });
 993            }
 994        }
 995
 996        match project_state.registered_buffers.entry(buffer_id) {
 997            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 998            hash_map::Entry::Vacant(entry) => {
 999                let buf = buffer.read(cx);
1000                let snapshot = buf.text_snapshot();
1001                let file = buf.file().cloned();
1002                let project_entity_id = project.entity_id();
1003                entry.insert(RegisteredBuffer {
1004                    snapshot,
1005                    file,
1006                    last_position: None,
1007                    _subscriptions: [
1008                        cx.subscribe(buffer, {
1009                            let project = project.downgrade();
1010                            move |this, buffer, event, cx| {
1011                                if let language::BufferEvent::Edited = event
1012                                    && let Some(project) = project.upgrade()
1013                                {
1014                                    this.report_changes_for_buffer(&buffer, &project, cx);
1015                                }
1016                            }
1017                        }),
1018                        cx.observe_release(buffer, move |this, _buffer, _cx| {
1019                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
1020                            else {
1021                                return;
1022                            };
1023                            project_state.registered_buffers.remove(&buffer_id);
1024                        }),
1025                    ],
1026                })
1027            }
1028        }
1029    }
1030
1031    fn report_changes_for_buffer(
1032        &mut self,
1033        buffer: &Entity<Buffer>,
1034        project: &Entity<Project>,
1035        cx: &mut Context<Self>,
1036    ) {
1037        let project_state = self.get_or_init_project(project, cx);
1038        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
1039
1040        let buf = buffer.read(cx);
1041        let new_file = buf.file().cloned();
1042        let new_snapshot = buf.text_snapshot();
1043        if new_snapshot.version == registered_buffer.snapshot.version {
1044            return;
1045        }
1046
1047        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
1048        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1049        let mut num_edits = 0usize;
1050        let mut total_deleted = 0usize;
1051        let mut total_inserted = 0usize;
1052        let mut edit_range: Option<Range<Anchor>> = None;
1053        let mut last_offset: Option<usize> = None;
1054
1055        for (edit, anchor_range) in
1056            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
1057        {
1058            num_edits += 1;
1059            total_deleted += edit.old.len();
1060            total_inserted += edit.new.len();
1061            edit_range = Some(match edit_range {
1062                None => anchor_range,
1063                Some(acc) => acc.start..anchor_range.end,
1064            });
1065            last_offset = Some(edit.new.end);
1066        }
1067
1068        if num_edits > 0 {
1069            let action_type = match (total_deleted, total_inserted, num_edits) {
1070                (0, ins, n) if ins == n => UserActionType::InsertChar,
1071                (0, _, _) => UserActionType::InsertSelection,
1072                (del, 0, n) if del == n => UserActionType::DeleteChar,
1073                (_, 0, _) => UserActionType::DeleteSelection,
1074                (_, ins, n) if ins == n => UserActionType::InsertChar,
1075                (_, _, _) => UserActionType::InsertSelection,
1076            };
1077
1078            if let Some(offset) = last_offset {
1079                let point = new_snapshot.offset_to_point(offset);
1080                let timestamp_epoch_ms = SystemTime::now()
1081                    .duration_since(UNIX_EPOCH)
1082                    .map(|d| d.as_millis() as u64)
1083                    .unwrap_or(0);
1084                project_state.record_user_action(UserActionRecord {
1085                    action_type,
1086                    buffer_id: buffer.entity_id(),
1087                    line_number: point.row,
1088                    offset,
1089                    timestamp_epoch_ms,
1090                });
1091            }
1092        }
1093
1094        let events = &mut project_state.events;
1095
1096        let now = cx.background_executor().now();
1097        if let Some(last_event) = project_state.last_event.as_mut() {
1098            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
1099                == last_event.new_snapshot.remote_id()
1100                && old_snapshot.version == last_event.new_snapshot.version;
1101
1102            let should_coalesce = is_next_snapshot_of_same_buffer
1103                && edit_range
1104                    .as_ref()
1105                    .zip(last_event.edit_range.as_ref())
1106                    .is_some_and(|(a, b)| {
1107                        let a = a.to_point(&new_snapshot);
1108                        let b = b.to_point(&new_snapshot);
1109                        if a.start > b.end {
1110                            a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
1111                        } else if b.start > a.end {
1112                            b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
1113                        } else {
1114                            true
1115                        }
1116                    });
1117
1118            if should_coalesce {
1119                let pause_elapsed = last_event
1120                    .last_edit_time
1121                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
1122                    .unwrap_or(false);
1123                if pause_elapsed {
1124                    last_event.snapshot_after_last_editing_pause =
1125                        Some(last_event.new_snapshot.clone());
1126                }
1127
1128                last_event.edit_range = edit_range;
1129                last_event.new_snapshot = new_snapshot;
1130                last_event.last_edit_time = Some(now);
1131                return;
1132            }
1133        }
1134
1135        if let Some(event) = project_state.last_event.take() {
1136            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
1137                if events.len() + 1 >= EVENT_COUNT_MAX {
1138                    events.pop_front();
1139                }
1140                events.push_back(event);
1141            }
1142        }
1143
1144        project_state.last_event = Some(LastEvent {
1145            old_file,
1146            new_file,
1147            old_snapshot,
1148            new_snapshot,
1149            edit_range,
1150            snapshot_after_last_editing_pause: None,
1151            last_edit_time: Some(now),
1152        });
1153    }
1154
1155    fn prediction_at(
1156        &mut self,
1157        buffer: &Entity<Buffer>,
1158        position: Option<language::Anchor>,
1159        project: &Entity<Project>,
1160        cx: &App,
1161    ) -> Option<BufferEditPrediction<'_>> {
1162        let project_state = self.projects.get_mut(&project.entity_id())?;
1163        if let Some(position) = position
1164            && let Some(buffer) = project_state
1165                .registered_buffers
1166                .get_mut(&buffer.entity_id())
1167        {
1168            buffer.last_position = Some(position);
1169        }
1170
1171        let CurrentEditPrediction {
1172            requested_by,
1173            prediction,
1174            ..
1175        } = project_state.current_prediction.as_ref()?;
1176
1177        if prediction.targets_buffer(buffer.read(cx)) {
1178            Some(BufferEditPrediction::Local { prediction })
1179        } else {
1180            let show_jump = match requested_by {
1181                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1182                    requested_by_buffer_id == &buffer.entity_id()
1183                }
1184                PredictionRequestedBy::DiagnosticsUpdate => true,
1185            };
1186
1187            if show_jump {
1188                Some(BufferEditPrediction::Jump { prediction })
1189            } else {
1190                None
1191            }
1192        }
1193    }
1194
1195    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1196        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1197            return;
1198        };
1199
1200        let Some(current_prediction) = project_state.current_prediction.take() else {
1201            return;
1202        };
1203
1204        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1205            project_state.cancel_pending_prediction(pending_prediction, cx);
1206        }
1207
1208        match self.edit_prediction_model {
1209            EditPredictionModel::Sweep => {
1210                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1211            }
1212            EditPredictionModel::Mercury => {
1213                mercury::edit_prediction_accepted(
1214                    current_prediction.prediction.id,
1215                    self.client.http_client(),
1216                    cx,
1217                );
1218            }
1219            EditPredictionModel::Ollama => {}
1220            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1221                zeta2::edit_prediction_accepted(self, current_prediction, cx)
1222            }
1223        }
1224    }
1225
1226    async fn handle_rejected_predictions(
1227        rx: UnboundedReceiver<EditPredictionRejection>,
1228        client: Arc<Client>,
1229        llm_token: LlmApiToken,
1230        app_version: Version,
1231        background_executor: BackgroundExecutor,
1232    ) {
1233        let mut rx = std::pin::pin!(rx.peekable());
1234        let mut batched = Vec::new();
1235
1236        while let Some(rejection) = rx.next().await {
1237            batched.push(rejection);
1238
1239            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1240                select_biased! {
1241                    next = rx.as_mut().peek().fuse() => {
1242                        if next.is_some() {
1243                            continue;
1244                        }
1245                    }
1246                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1247                }
1248            }
1249
1250            let url = client
1251                .http_client()
1252                .build_zed_llm_url("/predict_edits/reject", &[])
1253                .unwrap();
1254
1255            let flush_count = batched
1256                .len()
1257                // in case items have accumulated after failure
1258                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1259            let start = batched.len() - flush_count;
1260
1261            let body = RejectEditPredictionsBodyRef {
1262                rejections: &batched[start..],
1263            };
1264
1265            let result = Self::send_api_request::<()>(
1266                |builder| {
1267                    let req = builder
1268                        .uri(url.as_ref())
1269                        .body(serde_json::to_string(&body)?.into());
1270                    anyhow::Ok(req?)
1271                },
1272                client.clone(),
1273                llm_token.clone(),
1274                app_version.clone(),
1275                true,
1276            )
1277            .await;
1278
1279            if result.log_err().is_some() {
1280                batched.drain(start..);
1281            }
1282        }
1283    }
1284
1285    fn reject_current_prediction(
1286        &mut self,
1287        reason: EditPredictionRejectReason,
1288        project: &Entity<Project>,
1289        cx: &App,
1290    ) {
1291        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1292            project_state.pending_predictions.clear();
1293            if let Some(prediction) = project_state.current_prediction.take() {
1294                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
1295            }
1296        };
1297    }
1298
1299    fn did_show_current_prediction(
1300        &mut self,
1301        project: &Entity<Project>,
1302        display_type: edit_prediction_types::SuggestionDisplayType,
1303        cx: &mut Context<Self>,
1304    ) {
1305        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1306            return;
1307        };
1308
1309        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1310            return;
1311        };
1312
1313        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1314        let previous_shown_with = current_prediction.shown_with;
1315
1316        if previous_shown_with.is_none() || !is_jump {
1317            current_prediction.shown_with = Some(display_type);
1318        }
1319
1320        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1321
1322        if is_first_non_jump_show {
1323            current_prediction.was_shown = true;
1324        }
1325
1326        let display_type_changed = previous_shown_with != Some(display_type);
1327
1328        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1329            sweep_ai::edit_prediction_shown(
1330                &self.sweep_ai,
1331                self.client.clone(),
1332                &current_prediction.prediction,
1333                display_type,
1334                cx,
1335            );
1336        }
1337
1338        if is_first_non_jump_show {
1339            self.shown_predictions
1340                .push_front(current_prediction.prediction.clone());
1341            if self.shown_predictions.len() > 50 {
1342                let completion = self.shown_predictions.pop_back().unwrap();
1343                self.rated_predictions.remove(&completion.id);
1344            }
1345        }
1346    }
1347
1348    fn reject_prediction(
1349        &mut self,
1350        prediction_id: EditPredictionId,
1351        reason: EditPredictionRejectReason,
1352        was_shown: bool,
1353        cx: &App,
1354    ) {
1355        match self.edit_prediction_model {
1356            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1357                self.reject_predictions_tx
1358                    .unbounded_send(EditPredictionRejection {
1359                        request_id: prediction_id.to_string(),
1360                        reason,
1361                        was_shown,
1362                    })
1363                    .log_err();
1364            }
1365            EditPredictionModel::Sweep | EditPredictionModel::Ollama => {}
1366            EditPredictionModel::Mercury => {
1367                mercury::edit_prediction_rejected(
1368                    prediction_id,
1369                    was_shown,
1370                    reason,
1371                    self.client.http_client(),
1372                    cx,
1373                );
1374            }
1375        }
1376    }
1377
1378    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1379        self.projects
1380            .get(&project.entity_id())
1381            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1382    }
1383
1384    pub fn refresh_prediction_from_buffer(
1385        &mut self,
1386        project: Entity<Project>,
1387        buffer: Entity<Buffer>,
1388        position: language::Anchor,
1389        cx: &mut Context<Self>,
1390    ) {
1391        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1392            let Some(request_task) = this
1393                .update(cx, |this, cx| {
1394                    this.request_prediction(
1395                        &project,
1396                        &buffer,
1397                        position,
1398                        PredictEditsRequestTrigger::Other,
1399                        cx,
1400                    )
1401                })
1402                .log_err()
1403            else {
1404                return Task::ready(anyhow::Ok(None));
1405            };
1406
1407            cx.spawn(async move |_cx| {
1408                request_task.await.map(|prediction_result| {
1409                    prediction_result.map(|prediction_result| {
1410                        (
1411                            prediction_result,
1412                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1413                        )
1414                    })
1415                })
1416            })
1417        })
1418    }
1419
1420    pub fn refresh_prediction_from_diagnostics(
1421        &mut self,
1422        project: Entity<Project>,
1423        cx: &mut Context<Self>,
1424    ) {
1425        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1426            return;
1427        };
1428
1429        // Prefer predictions from buffer
1430        if project_state.current_prediction.is_some() {
1431            return;
1432        };
1433
1434        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1435            let Some((active_buffer, snapshot, cursor_point)) = this
1436                .read_with(cx, |this, cx| {
1437                    let project_state = this.projects.get(&project.entity_id())?;
1438                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
1439                    let snapshot = buffer.read(cx).snapshot();
1440
1441                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
1442                        return None;
1443                    }
1444
1445                    let cursor_point = position
1446                        .map(|pos| pos.to_point(&snapshot))
1447                        .unwrap_or_default();
1448
1449                    Some((buffer, snapshot, cursor_point))
1450                })
1451                .log_err()
1452                .flatten()
1453            else {
1454                return Task::ready(anyhow::Ok(None));
1455            };
1456
1457            cx.spawn(async move |cx| {
1458                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1459                    active_buffer,
1460                    &snapshot,
1461                    Default::default(),
1462                    cursor_point,
1463                    &project,
1464                    cx,
1465                )
1466                .await?
1467                else {
1468                    return anyhow::Ok(None);
1469                };
1470
1471                let Some(prediction_result) = this
1472                    .update(cx, |this, cx| {
1473                        this.request_prediction(
1474                            &project,
1475                            &jump_buffer,
1476                            jump_position,
1477                            PredictEditsRequestTrigger::Diagnostics,
1478                            cx,
1479                        )
1480                    })?
1481                    .await?
1482                else {
1483                    return anyhow::Ok(None);
1484                };
1485
1486                this.update(cx, |this, cx| {
1487                    Some((
1488                        if this
1489                            .get_or_init_project(&project, cx)
1490                            .current_prediction
1491                            .is_none()
1492                        {
1493                            prediction_result
1494                        } else {
1495                            EditPredictionResult {
1496                                id: prediction_result.id,
1497                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1498                            }
1499                        },
1500                        PredictionRequestedBy::DiagnosticsUpdate,
1501                    ))
1502                })
1503            })
1504        });
1505    }
1506
1507    fn predictions_enabled_at(
1508        snapshot: &BufferSnapshot,
1509        position: Option<language::Anchor>,
1510        cx: &App,
1511    ) -> bool {
1512        let file = snapshot.file();
1513        let all_settings = all_language_settings(file, cx);
1514        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1515            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1516        {
1517            return false;
1518        }
1519
1520        if let Some(last_position) = position {
1521            let settings = snapshot.settings_at(last_position, cx);
1522
1523            if !settings.edit_predictions_disabled_in.is_empty()
1524                && let Some(scope) = snapshot.language_scope_at(last_position)
1525                && let Some(scope_name) = scope.override_name()
1526                && settings
1527                    .edit_predictions_disabled_in
1528                    .iter()
1529                    .any(|s| s == scope_name)
1530            {
1531                return false;
1532            }
1533        }
1534
1535        true
1536    }
1537
1538    #[cfg(not(test))]
1539    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1540    #[cfg(test)]
1541    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1542
1543    fn queue_prediction_refresh(
1544        &mut self,
1545        project: Entity<Project>,
1546        throttle_entity: EntityId,
1547        cx: &mut Context<Self>,
1548        do_refresh: impl FnOnce(
1549            WeakEntity<Self>,
1550            &mut AsyncApp,
1551        )
1552            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1553        + 'static,
1554    ) {
1555        let is_ollama = self.edit_prediction_model == EditPredictionModel::Ollama;
1556        let drop_on_cancel = is_ollama;
1557        let max_pending_predictions = if is_ollama { 1 } else { 2 };
1558        let project_state = self.get_or_init_project(&project, cx);
1559        let pending_prediction_id = project_state.next_pending_prediction_id;
1560        project_state.next_pending_prediction_id += 1;
1561        let last_request = project_state.last_prediction_refresh;
1562
1563        let task = cx.spawn(async move |this, cx| {
1564            if let Some((last_entity, last_timestamp)) = last_request
1565                && throttle_entity == last_entity
1566                && let Some(timeout) =
1567                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1568            {
1569                cx.background_executor().timer(timeout).await;
1570            }
1571
1572            // If this task was cancelled before the throttle timeout expired,
1573            // do not perform a request.
1574            let mut is_cancelled = true;
1575            this.update(cx, |this, cx| {
1576                let project_state = this.get_or_init_project(&project, cx);
1577                if !project_state
1578                    .cancelled_predictions
1579                    .remove(&pending_prediction_id)
1580                {
1581                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1582                    is_cancelled = false;
1583                }
1584            })
1585            .ok();
1586            if is_cancelled {
1587                return None;
1588            }
1589
1590            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1591            let new_prediction_id = new_prediction_result
1592                .as_ref()
1593                .map(|(prediction, _)| prediction.id.clone());
1594
1595            // When a prediction completes, remove it from the pending list, and cancel
1596            // any pending predictions that were enqueued before it.
1597            this.update(cx, |this, cx| {
1598                let project_state = this.get_or_init_project(&project, cx);
1599
1600                let is_cancelled = project_state
1601                    .cancelled_predictions
1602                    .remove(&pending_prediction_id);
1603
1604                let new_current_prediction = if !is_cancelled
1605                    && let Some((prediction_result, requested_by)) = new_prediction_result
1606                {
1607                    match prediction_result.prediction {
1608                        Ok(prediction) => {
1609                            let new_prediction = CurrentEditPrediction {
1610                                requested_by,
1611                                prediction,
1612                                was_shown: false,
1613                                shown_with: None,
1614                            };
1615
1616                            if let Some(current_prediction) =
1617                                project_state.current_prediction.as_ref()
1618                            {
1619                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1620                                {
1621                                    this.reject_current_prediction(
1622                                        EditPredictionRejectReason::Replaced,
1623                                        &project,
1624                                        cx,
1625                                    );
1626
1627                                    Some(new_prediction)
1628                                } else {
1629                                    this.reject_prediction(
1630                                        new_prediction.prediction.id,
1631                                        EditPredictionRejectReason::CurrentPreferred,
1632                                        false,
1633                                        cx,
1634                                    );
1635                                    None
1636                                }
1637                            } else {
1638                                Some(new_prediction)
1639                            }
1640                        }
1641                        Err(reject_reason) => {
1642                            this.reject_prediction(prediction_result.id, reject_reason, false, cx);
1643                            None
1644                        }
1645                    }
1646                } else {
1647                    None
1648                };
1649
1650                let project_state = this.get_or_init_project(&project, cx);
1651
1652                if let Some(new_prediction) = new_current_prediction {
1653                    project_state.current_prediction = Some(new_prediction);
1654                }
1655
1656                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1657                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1658                    if pending_prediction.id == pending_prediction_id {
1659                        pending_predictions.remove(ix);
1660                        for pending_prediction in pending_predictions.drain(0..ix) {
1661                            project_state.cancel_pending_prediction(pending_prediction, cx)
1662                        }
1663                        break;
1664                    }
1665                }
1666                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1667                cx.notify();
1668            })
1669            .ok();
1670
1671            new_prediction_id
1672        });
1673
1674        if project_state.pending_predictions.len() < max_pending_predictions {
1675            project_state.pending_predictions.push(PendingPrediction {
1676                id: pending_prediction_id,
1677                task,
1678                drop_on_cancel,
1679            });
1680        } else {
1681            let pending_prediction = project_state.pending_predictions.pop().unwrap();
1682            project_state.pending_predictions.push(PendingPrediction {
1683                id: pending_prediction_id,
1684                task,
1685                drop_on_cancel,
1686            });
1687            project_state.cancel_pending_prediction(pending_prediction, cx);
1688        }
1689    }
1690
1691    pub fn request_prediction(
1692        &mut self,
1693        project: &Entity<Project>,
1694        active_buffer: &Entity<Buffer>,
1695        position: language::Anchor,
1696        trigger: PredictEditsRequestTrigger,
1697        cx: &mut Context<Self>,
1698    ) -> Task<Result<Option<EditPredictionResult>>> {
1699        self.request_prediction_internal(
1700            project.clone(),
1701            active_buffer.clone(),
1702            position,
1703            trigger,
1704            cx.has_flag::<Zeta2FeatureFlag>(),
1705            cx,
1706        )
1707    }
1708
1709    fn request_prediction_internal(
1710        &mut self,
1711        project: Entity<Project>,
1712        active_buffer: Entity<Buffer>,
1713        position: language::Anchor,
1714        trigger: PredictEditsRequestTrigger,
1715        allow_jump: bool,
1716        cx: &mut Context<Self>,
1717    ) -> Task<Result<Option<EditPredictionResult>>> {
1718        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1719
1720        self.get_or_init_project(&project, cx);
1721        let project_state = self.projects.get(&project.entity_id()).unwrap();
1722        let stored_events = project_state.events_split_by_pause(cx);
1723        let has_events = !stored_events.is_empty();
1724        let events: Vec<Arc<zeta_prompt::Event>> =
1725            stored_events.into_iter().map(|e| e.event).collect();
1726        let debug_tx = project_state.debug_tx.clone();
1727
1728        let snapshot = active_buffer.read(cx).snapshot();
1729        let cursor_point = position.to_point(&snapshot);
1730        let current_offset = position.to_offset(&snapshot);
1731
1732        let mut user_actions: Vec<UserActionRecord> =
1733            project_state.user_actions.iter().cloned().collect();
1734
1735        if let Some(last_action) = user_actions.last() {
1736            if last_action.buffer_id == active_buffer.entity_id()
1737                && current_offset != last_action.offset
1738            {
1739                let timestamp_epoch_ms = SystemTime::now()
1740                    .duration_since(UNIX_EPOCH)
1741                    .map(|d| d.as_millis() as u64)
1742                    .unwrap_or(0);
1743                user_actions.push(UserActionRecord {
1744                    action_type: UserActionType::CursorMovement,
1745                    buffer_id: active_buffer.entity_id(),
1746                    line_number: cursor_point.row,
1747                    offset: current_offset,
1748                    timestamp_epoch_ms,
1749                });
1750            }
1751        }
1752        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1753        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1754        let diagnostic_search_range =
1755            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1756
1757        let related_files = self.context_for_project(&project, cx);
1758
1759        let inputs = EditPredictionModelInput {
1760            project: project.clone(),
1761            buffer: active_buffer.clone(),
1762            snapshot: snapshot.clone(),
1763            position,
1764            events,
1765            related_files,
1766            recent_paths: project_state.recent_paths.clone(),
1767            trigger,
1768            diagnostic_search_range: diagnostic_search_range.clone(),
1769            debug_tx,
1770            user_actions,
1771        };
1772
1773        let task = match &self.edit_prediction_model {
1774            EditPredictionModel::Zeta1 => {
1775                if should_send_testing_zeta2_request() {
1776                    let mut zeta2_inputs = inputs.clone();
1777                    zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing;
1778                    zeta2::request_prediction_with_zeta2(self, zeta2_inputs, cx).detach();
1779                }
1780                zeta1::request_prediction_with_zeta1(self, inputs, cx)
1781            }
1782            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
1783            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1784            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1785            EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx),
1786        };
1787
1788        cx.spawn(async move |this, cx| {
1789            let prediction = task.await?;
1790
1791            if prediction.is_none() && allow_jump {
1792                let cursor_point = position.to_point(&snapshot);
1793                if has_events
1794                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1795                        active_buffer.clone(),
1796                        &snapshot,
1797                        diagnostic_search_range,
1798                        cursor_point,
1799                        &project,
1800                        cx,
1801                    )
1802                    .await?
1803                {
1804                    return this
1805                        .update(cx, |this, cx| {
1806                            this.request_prediction_internal(
1807                                project,
1808                                jump_buffer,
1809                                jump_position,
1810                                trigger,
1811                                false,
1812                                cx,
1813                            )
1814                        })?
1815                        .await;
1816                }
1817
1818                return anyhow::Ok(None);
1819            }
1820
1821            Ok(prediction)
1822        })
1823    }
1824
1825    async fn next_diagnostic_location(
1826        active_buffer: Entity<Buffer>,
1827        active_buffer_snapshot: &BufferSnapshot,
1828        active_buffer_diagnostic_search_range: Range<Point>,
1829        active_buffer_cursor_point: Point,
1830        project: &Entity<Project>,
1831        cx: &mut AsyncApp,
1832    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1833        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1834        let mut jump_location = active_buffer_snapshot
1835            .diagnostic_groups(None)
1836            .into_iter()
1837            .filter_map(|(_, group)| {
1838                let range = &group.entries[group.primary_ix]
1839                    .range
1840                    .to_point(&active_buffer_snapshot);
1841                if range.overlaps(&active_buffer_diagnostic_search_range) {
1842                    None
1843                } else {
1844                    Some(range.start)
1845                }
1846            })
1847            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1848            .map(|position| {
1849                (
1850                    active_buffer.clone(),
1851                    active_buffer_snapshot.anchor_before(position),
1852                )
1853            });
1854
1855        if jump_location.is_none() {
1856            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1857                let file = buffer.file()?;
1858
1859                Some(ProjectPath {
1860                    worktree_id: file.worktree_id(cx),
1861                    path: file.path().clone(),
1862                })
1863            });
1864
1865            let buffer_task = project.update(cx, |project, cx| {
1866                let (path, _, _) = project
1867                    .diagnostic_summaries(false, cx)
1868                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1869                    .max_by_key(|(path, _, _)| {
1870                        // find the buffer with errors that shares most parent directories
1871                        path.path
1872                            .components()
1873                            .zip(
1874                                active_buffer_path
1875                                    .as_ref()
1876                                    .map(|p| p.path.components())
1877                                    .unwrap_or_default(),
1878                            )
1879                            .take_while(|(a, b)| a == b)
1880                            .count()
1881                    })?;
1882
1883                Some(project.open_buffer(path, cx))
1884            });
1885
1886            if let Some(buffer_task) = buffer_task {
1887                let closest_buffer = buffer_task.await?;
1888
1889                jump_location = closest_buffer
1890                    .read_with(cx, |buffer, _cx| {
1891                        buffer
1892                            .buffer_diagnostics(None)
1893                            .into_iter()
1894                            .min_by_key(|entry| entry.diagnostic.severity)
1895                            .map(|entry| entry.range.start)
1896                    })
1897                    .map(|position| (closest_buffer, position));
1898            }
1899        }
1900
1901        anyhow::Ok(jump_location)
1902    }
1903
1904    async fn send_raw_llm_request(
1905        request: RawCompletionRequest,
1906        client: Arc<Client>,
1907        custom_url: Option<Arc<Url>>,
1908        llm_token: LlmApiToken,
1909        app_version: Version,
1910    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1911        let url = if let Some(custom_url) = custom_url {
1912            custom_url.as_ref().clone()
1913        } else {
1914            client
1915                .http_client()
1916                .build_zed_llm_url("/predict_edits/raw", &[])?
1917        };
1918
1919        Self::send_api_request(
1920            |builder| {
1921                let req = builder
1922                    .uri(url.as_ref())
1923                    .body(serde_json::to_string(&request)?.into());
1924                Ok(req?)
1925            },
1926            client,
1927            llm_token,
1928            app_version,
1929            true,
1930        )
1931        .await
1932    }
1933
1934    pub(crate) async fn send_v3_request(
1935        input: ZetaPromptInput,
1936        client: Arc<Client>,
1937        llm_token: LlmApiToken,
1938        app_version: Version,
1939        trigger: PredictEditsRequestTrigger,
1940    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
1941        let url = client
1942            .http_client()
1943            .build_zed_llm_url("/predict_edits/v3", &[])?;
1944
1945        let request = PredictEditsV3Request { input, trigger };
1946
1947        let json_bytes = serde_json::to_vec(&request)?;
1948        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
1949
1950        Self::send_api_request(
1951            |builder| {
1952                let req = builder
1953                    .uri(url.as_ref())
1954                    .header("Content-Encoding", "zstd")
1955                    .body(compressed.clone().into());
1956                Ok(req?)
1957            },
1958            client,
1959            llm_token,
1960            app_version,
1961            true,
1962        )
1963        .await
1964    }
1965
1966    fn handle_api_response<T>(
1967        this: &WeakEntity<Self>,
1968        response: Result<(T, Option<EditPredictionUsage>)>,
1969        cx: &mut gpui::AsyncApp,
1970    ) -> Result<T> {
1971        match response {
1972            Ok((data, usage)) => {
1973                if let Some(usage) = usage {
1974                    this.update(cx, |this, cx| {
1975                        this.user_store.update(cx, |user_store, cx| {
1976                            user_store.update_edit_prediction_usage(usage, cx);
1977                        });
1978                    })
1979                    .ok();
1980                }
1981                Ok(data)
1982            }
1983            Err(err) => {
1984                if err.is::<ZedUpdateRequiredError>() {
1985                    cx.update(|cx| {
1986                        this.update(cx, |this, _cx| {
1987                            this.update_required = true;
1988                        })
1989                        .ok();
1990
1991                        let error_message: SharedString = err.to_string().into();
1992                        show_app_notification(
1993                            NotificationId::unique::<ZedUpdateRequiredError>(),
1994                            cx,
1995                            move |cx| {
1996                                cx.new(|cx| {
1997                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1998                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1999                                })
2000                            },
2001                        );
2002                    });
2003                }
2004                Err(err)
2005            }
2006        }
2007    }
2008
2009    async fn send_api_request<Res>(
2010        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2011        client: Arc<Client>,
2012        llm_token: LlmApiToken,
2013        app_version: Version,
2014        require_auth: bool,
2015    ) -> Result<(Res, Option<EditPredictionUsage>)>
2016    where
2017        Res: DeserializeOwned,
2018    {
2019        let http_client = client.http_client();
2020
2021        let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
2022            Some(custom_token)
2023        } else if require_auth {
2024            Some(llm_token.acquire(&client).await?)
2025        } else {
2026            llm_token.acquire(&client).await.ok()
2027        };
2028        let mut did_retry = false;
2029
2030        loop {
2031            let request_builder = http_client::Request::builder().method(Method::POST);
2032
2033            let mut request_builder = request_builder
2034                .header("Content-Type", "application/json")
2035                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2036
2037            // Only add Authorization header if we have a token
2038            if let Some(ref token_value) = token {
2039                request_builder =
2040                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2041            }
2042
2043            let request = build(request_builder)?;
2044
2045            let mut response = http_client.send(request).await?;
2046
2047            if let Some(minimum_required_version) = response
2048                .headers()
2049                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2050                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2051            {
2052                anyhow::ensure!(
2053                    app_version >= minimum_required_version,
2054                    ZedUpdateRequiredError {
2055                        minimum_version: minimum_required_version
2056                    }
2057                );
2058            }
2059
2060            if response.status().is_success() {
2061                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2062
2063                let mut body = Vec::new();
2064                response.body_mut().read_to_end(&mut body).await?;
2065                return Ok((serde_json::from_slice(&body)?, usage));
2066            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2067                did_retry = true;
2068                token = Some(llm_token.refresh(&client).await?);
2069            } else {
2070                let mut body = String::new();
2071                response.body_mut().read_to_string(&mut body).await?;
2072                anyhow::bail!(
2073                    "Request failed with status: {:?}\nBody: {}",
2074                    response.status(),
2075                    body
2076                );
2077            }
2078        }
2079    }
2080
2081    pub fn refresh_context(
2082        &mut self,
2083        project: &Entity<Project>,
2084        buffer: &Entity<language::Buffer>,
2085        cursor_position: language::Anchor,
2086        cx: &mut Context<Self>,
2087    ) {
2088        self.get_or_init_project(project, cx)
2089            .context
2090            .update(cx, |store, cx| {
2091                store.refresh(buffer.clone(), cursor_position, cx);
2092            });
2093    }
2094
2095    #[cfg(feature = "cli-support")]
2096    pub fn set_context_for_buffer(
2097        &mut self,
2098        project: &Entity<Project>,
2099        related_files: Vec<RelatedFile>,
2100        cx: &mut Context<Self>,
2101    ) {
2102        self.get_or_init_project(project, cx)
2103            .context
2104            .update(cx, |store, cx| {
2105                store.set_related_files(related_files, cx);
2106            });
2107    }
2108
2109    #[cfg(feature = "cli-support")]
2110    pub fn set_recent_paths_for_project(
2111        &mut self,
2112        project: &Entity<Project>,
2113        paths: impl IntoIterator<Item = project::ProjectPath>,
2114        cx: &mut Context<Self>,
2115    ) {
2116        let project_state = self.get_or_init_project(project, cx);
2117        project_state.recent_paths = paths.into_iter().collect();
2118    }
2119
2120    fn is_file_open_source(
2121        &self,
2122        project: &Entity<Project>,
2123        file: &Arc<dyn File>,
2124        cx: &App,
2125    ) -> bool {
2126        if !file.is_local() || file.is_private() {
2127            return false;
2128        }
2129        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2130            return false;
2131        };
2132        project_state
2133            .license_detection_watchers
2134            .get(&file.worktree_id(cx))
2135            .as_ref()
2136            .is_some_and(|watcher| watcher.is_project_open_source())
2137    }
2138
2139    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2140        self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2141    }
2142
2143    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2144        if !self.data_collection_choice.is_enabled(cx) {
2145            return false;
2146        }
2147        events.iter().all(|event| {
2148            matches!(
2149                event.as_ref(),
2150                zeta_prompt::Event::BufferChange {
2151                    in_open_source_repo: true,
2152                    ..
2153                }
2154            )
2155        })
2156    }
2157
2158    fn load_data_collection_choice() -> DataCollectionChoice {
2159        let choice = KEY_VALUE_STORE
2160            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2161            .log_err()
2162            .flatten();
2163
2164        match choice.as_deref() {
2165            Some("true") => DataCollectionChoice::Enabled,
2166            Some("false") => DataCollectionChoice::Disabled,
2167            Some(_) => {
2168                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2169                DataCollectionChoice::NotAnswered
2170            }
2171            None => DataCollectionChoice::NotAnswered,
2172        }
2173    }
2174
2175    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2176        self.data_collection_choice = self.data_collection_choice.toggle();
2177        let new_choice = self.data_collection_choice;
2178        let is_enabled = new_choice.is_enabled(cx);
2179        db::write_and_log(cx, move || {
2180            KEY_VALUE_STORE.write_kvp(
2181                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2182                is_enabled.to_string(),
2183            )
2184        });
2185    }
2186
    /// Iterates (from either end) over the predictions stored in `shown_predictions`.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2190
    /// Number of predictions stored in `shown_predictions`.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2194
    /// Whether `id` is in `rated_predictions`, i.e. `rate_prediction` was called for it.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2198
2199    pub fn rate_prediction(
2200        &mut self,
2201        prediction: &EditPrediction,
2202        rating: EditPredictionRating,
2203        feedback: String,
2204        cx: &mut Context<Self>,
2205    ) {
2206        self.rated_predictions.insert(prediction.id.clone());
2207        telemetry::event!(
2208            "Edit Prediction Rated",
2209            request_id = prediction.id.to_string(),
2210            rating,
2211            inputs = prediction.inputs,
2212            output = prediction
2213                .edit_preview
2214                .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2215            feedback
2216        );
2217        self.client.telemetry().flush_events().detach();
2218        cx.notify();
2219    }
2220}
2221
2222pub(crate) fn filter_redundant_excerpts(
2223    mut related_files: Vec<RelatedFile>,
2224    cursor_path: &Path,
2225    cursor_row_range: Range<u32>,
2226) -> Vec<RelatedFile> {
2227    for file in &mut related_files {
2228        if file.path.as_ref() == cursor_path {
2229            file.excerpts.retain(|excerpt| {
2230                excerpt.row_range.start < cursor_row_range.start
2231                    || excerpt.row_range.end > cursor_row_range.end
2232            });
2233        }
2234    }
2235    related_files.retain(|file| !file.excerpts.is_empty());
2236    related_files
2237}
2238
/// Returned by `send_api_request` when the server's minimum-required-version
/// response header names a version newer than the running app.
///
/// `handle_api_response` checks for this error type to set `update_required` and
/// show a notification linking to the Zed releases page.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest app version the server reported as acceptable.
    minimum_version: Version,
}
2246
/// The user's edit-prediction data-collection choice.
///
/// Persisted under `ZED_PREDICT_DATA_COLLECTION_CHOICE` as `"true"` / `"false"`.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice stored, or the stored value was unrecognized or unreadable.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2253
2254impl DataCollectionChoice {
2255    pub fn is_enabled(self, cx: &App) -> bool {
2256        if cx.is_staff() {
2257            return true;
2258        }
2259        match self {
2260            Self::Enabled => true,
2261            Self::NotAnswered | Self::Disabled => false,
2262        }
2263    }
2264
2265    #[must_use]
2266    pub fn toggle(&self) -> DataCollectionChoice {
2267        match self {
2268            Self::Enabled => Self::Disabled,
2269            Self::Disabled => Self::Enabled,
2270            Self::NotAnswered => Self::Enabled,
2271        }
2272    }
2273}
2274
2275impl From<bool> for DataCollectionChoice {
2276    fn from(value: bool) -> Self {
2277        match value {
2278            true => DataCollectionChoice::Enabled,
2279            false => DataCollectionChoice::Disabled,
2280        }
2281    }
2282}
2283
/// Marker type for the dismissable Zed Predict upsell, keyed in the key-value store
/// by its `Dismissable::KEY`.
struct ZedPredictUpsell;
2285
2286impl Dismissable for ZedPredictUpsell {
2287    const KEY: &'static str = "dismissed-edit-predict-upsell";
2288
2289    fn dismissed() -> bool {
2290        // To make this backwards compatible with older versions of Zed, we
2291        // check if the user has seen the previous Edit Prediction Onboarding
2292        // before, by checking the data collection choice which was written to
2293        // the database once the user clicked on "Accept and Enable"
2294        if KEY_VALUE_STORE
2295            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2296            .log_err()
2297            .is_some_and(|s| s.is_some())
2298        {
2299            return true;
2300        }
2301
2302        KEY_VALUE_STORE
2303            .read_kvp(Self::KEY)
2304            .log_err()
2305            .is_some_and(|s| s.is_some())
2306    }
2307}
2308
/// Whether the Zed Predict upsell modal should be shown, i.e. it has not been
/// dismissed (directly, or implicitly by a previously stored data-collection choice).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2312
2313pub fn init(cx: &mut App) {
2314    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2315        workspace.register_action(
2316            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2317                ZedPredictModal::toggle(
2318                    workspace,
2319                    workspace.user_store().clone(),
2320                    workspace.client().clone(),
2321                    window,
2322                    cx,
2323                )
2324            },
2325        );
2326
2327        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2328            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2329                settings
2330                    .project
2331                    .all_languages
2332                    .edit_predictions
2333                    .get_or_insert_default()
2334                    .provider = Some(EditPredictionProvider::None)
2335            });
2336        });
2337        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2338            EditPredictionStore::try_global(cx).and_then(|store| {
2339                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2340            })
2341        }
2342
2343        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2344            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2345                copilot_ui::initiate_sign_in(copilot, window, cx);
2346            }
2347        });
2348        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2349            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2350                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2351            }
2352        });
2353        workspace.register_action(|workspace, _: &SignOut, window, cx| {
2354            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2355                copilot_ui::initiate_sign_out(copilot, window, cx);
2356            }
2357        });
2358    })
2359    .detach();
2360}