edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{
   5    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   6};
   7use cloud_llm_client::{
   8    EditPredictionRejectReason, EditPredictionRejection,
   9    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  10    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  11};
  12use collections::{HashMap, HashSet};
  13use copilot::{Copilot, Reinstall, SignIn, SignOut};
  14use db::kvp::{Dismissable, KEY_VALUE_STORE};
  15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::{
  18    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  19    channel::mpsc::{self, UnboundedReceiver},
  20    select_biased,
  21};
  22use gpui::BackgroundExecutor;
  23use gpui::http_client::Url;
  24use gpui::{
  25    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  26    http_client::{self, AsyncBody, Method},
  27    prelude::*,
  28};
  29use language::language_settings::all_language_settings;
  30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  31use language::{BufferSnapshot, OffsetRangeExt};
  32use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  33use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  34use release_channel::AppVersion;
  35use semver::Version;
  36use serde::de::DeserializeOwned;
  37use settings::{EditPredictionProvider, Settings as _, update_settings_file};
  38use std::collections::{VecDeque, hash_map};
  39use std::env;
  40use text::Edit;
  41use workspace::Workspace;
  42use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  43
  44use std::mem;
  45use std::ops::Range;
  46use std::path::Path;
  47use std::rc::Rc;
  48use std::str::FromStr as _;
  49use std::sync::Arc;
  50use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  51use thiserror::Error;
  52use util::{RangeExt as _, ResultExt as _};
  53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  54
  55pub mod cursor_excerpt;
  56pub mod example_spec;
  57mod license_detection;
  58pub mod mercury;
  59pub mod ollama;
  60mod onboarding_modal;
  61pub mod open_ai_response;
  62mod prediction;
  63pub mod sweep_ai;
  64
  65pub mod udiff;
  66
  67mod capture_example;
  68mod zed_edit_prediction_delegate;
  69pub mod zeta1;
  70pub mod zeta2;
  71
  72#[cfg(test)]
  73mod edit_prediction_tests;
  74
  75use crate::license_detection::LicenseDetectionWatcher;
  76use crate::mercury::Mercury;
  77use crate::ollama::Ollama;
  78use crate::onboarding_modal::ZedPredictModal;
  79pub use crate::prediction::EditPrediction;
  80pub use crate::prediction::EditPredictionId;
  81use crate::prediction::EditPredictionResult;
  82pub use crate::sweep_ai::SweepAi;
  83pub use capture_example::capture_example;
  84pub use language_model::ApiKeyState;
  85pub use telemetry_events::EditPredictionRating;
  86pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  87
  88actions!(
  89    edit_prediction,
  90    [
  91        /// Resets the edit prediction onboarding state.
  92        ResetOnboarding,
  93        /// Clears the edit prediction history.
  94        ClearHistory,
  95    ]
  96);
  97
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Row distance used to decide whether two edit ranges are "near" each other
/// (compared against `lines_between_ranges` in `StoredEvent::can_merge`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Time window used when grouping the most recent change.
// NOTE(review): consumers are outside this view; presumably edits closer together
// than this are folded into the same event — confirm.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Key under which the user's data collection choice is stored.
// NOTE(review): presumably a `KEY_VALUE_STORE` key — confirm at the read/write sites.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval for rejection reporting.
// NOTE(review): presumably used by `handle_rejected_predictions` to batch reports — confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 104
/// Feature flag named `"zeta2"`; it is always enabled for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Staff get this flag regardless of rollout.
    fn enabled_for_staff() -> bool {
        true
    }
}
 114
/// gpui global holding the process-wide [`EditPredictionStore`] entity.
///
/// Installed by `EditPredictionStore::global` and read by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 119
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier (read from `ZED_ZETA_MODEL` by
    /// `EditPredictionStore::zeta2_raw_config_from_env`).
    pub model_id: Option<String>,
    /// Prompt format (parsed from `ZED_ZETA_FORMAT` by
    /// `EditPredictionStore::zeta2_raw_config_from_env`).
    pub format: ZetaFormat,
}
 128
/// Central edit prediction store. It holds the per-project state, the selected
/// prediction backend, and the client/token used to talk to the service.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever the global `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    // NOTE(review): presumably set when the server requires a newer client — confirm.
    update_required: bool,
    /// Backend used for predictions; `new` starts with `EditPredictionModel::Zeta2`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration, taken from the environment in `new`
    /// or overridden via `set_zeta2_raw_config`.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    pub ollama: Ollama,
    /// Loaded via `load_data_collection_choice` when the store is created.
    data_collection_choice: DataCollectionChoice,
    /// Sends rejections to the background `handle_rejected_predictions` task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    // NOTE(review): presumably recently shown predictions kept for rating — confirm.
    shown_predictions: VecDeque<EditPrediction>,
    // NOTE(review): presumably ids of predictions the user has already rated — confirm.
    rated_predictions: HashSet<EditPredictionId>,
}
 146
/// Backend used to generate edit predictions.
///
/// The derived `Default` is `Zeta1`, but `EditPredictionStore::new` starts with `Zeta2`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
    Ollama,
}
 156
/// Bundle of inputs handed to an edit prediction model.
// NOTE(review): presumably built once per prediction request — confirm at the construction site.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    /// Project the prediction is requested for.
    project: Entity<Project>,
    /// Buffer the prediction is requested for.
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` (presumably taken at request time — confirm).
    snapshot: BufferSnapshot,
    /// Position in `buffer` (presumably the cursor — confirm).
    position: Anchor,
    /// Recent edit events.
    events: Vec<Arc<zeta_prompt::Event>>,
    /// Related files supplying extra context.
    related_files: Vec<RelatedFile>,
    /// Recently visited project paths.
    recent_paths: VecDeque<ProjectPath>,
    /// What triggered the request.
    trigger: PredictEditsRequestTrigger,
    // NOTE(review): presumably the row range searched for diagnostics — confirm.
    diagnostic_search_range: Range<Point>,
    /// Optional channel for [`DebugEvent`]s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    /// Recent user actions.
    pub user_actions: Vec<UserActionRecord>,
}
 171
/// Diagnostic events delivered over a `debug_tx` channel
/// (see `ProjectState::debug_tx` and `EditPredictionModelInput::debug_tx`).
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Payload of [`DebugEvent::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    /// Entity id of the project whose context is being retrieved.
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Payload of [`DebugEvent::ContextRetrievalFinished`].
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    /// Entity id of the project whose context was retrieved.
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Key/value pairs describing the retrieval.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Payload of [`DebugEvent::EditPredictionStarted`].
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Prompt sent to the model, when one is available.
    pub prompt: Option<String>,
}

/// Payload of [`DebugEvent::EditPredictionFinished`].
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when one is available.
    pub model_output: Option<String>,
}
 207
/// Maximum number of records kept by `ProjectState::record_user_action`.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// A single recorded user action.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    /// Timestamp in milliseconds since the Unix epoch (per the field name).
    pub timestamp_epoch_ms: u64,
}

/// Kinds of user actions stored in [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 227
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The event as it is sent in prompts.
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the event's edits.
    pub old_snapshot: TextBufferSnapshot,
    /// Anchors (in the post-edit snapshot) spanning the edited region.
    pub edit_range: Range<Anchor>,
}
 235
 236impl StoredEvent {
 237    fn can_merge(
 238        &self,
 239        next_old_event: &&&StoredEvent,
 240        new_snapshot: &TextBufferSnapshot,
 241        last_edit_range: &Range<Anchor>,
 242    ) -> bool {
 243        // Events must be for the same buffer
 244        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 245            return false;
 246        }
 247        if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
 248            return false;
 249        }
 250
 251        let a_is_predicted = matches!(
 252            self.event.as_ref(),
 253            zeta_prompt::Event::BufferChange {
 254                predicted: true,
 255                ..
 256            }
 257        );
 258        let b_is_predicted = matches!(
 259            next_old_event.event.as_ref(),
 260            zeta_prompt::Event::BufferChange {
 261                predicted: true,
 262                ..
 263            }
 264        );
 265
 266        // If events come from the same source (both predicted or both manual) then
 267        // we would have coalesced them already.
 268        if a_is_predicted == b_is_predicted {
 269            return false;
 270        }
 271
 272        let left_range = self.edit_range.to_point(new_snapshot);
 273        let right_range = next_old_event.edit_range.to_point(new_snapshot);
 274        let latest_range = last_edit_range.to_point(&new_snapshot);
 275
 276        // Events near to the latest edit are not merged if their sources differ.
 277        if lines_between_ranges(&left_range, &latest_range)
 278            .min(lines_between_ranges(&right_range, &latest_range))
 279            <= CHANGE_GROUPING_LINE_SPAN
 280        {
 281            return false;
 282        }
 283
 284        // Events that are distant from each other are not merged.
 285        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 286            return false;
 287        }
 288
 289        true
 290    }
 291}
 292
 293fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 294    if left.start > right.end {
 295        return left.start.row - right.end.row;
 296    }
 297    if right.start > left.end {
 298        return right.start.row - left.end.row;
 299    }
 300    0
 301}
 302
/// Per-project edit prediction state, stored in `EditPredictionStore::projects`
/// under the project's entity id.
struct ProjectState {
    /// Finalized edit events; `events()` returns these ahead of `last_event`.
    events: VecDeque<StoredEvent>,
    /// The latest still-open edit event, finalized on demand by `events()`.
    last_event: Option<LastEvent>,
    /// Recently visited project paths.
    recent_paths: VecDeque<ProjectPath>,
    /// Registered buffers, keyed by buffer entity id (see `active_buffer`).
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently held for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
    // NOTE(review): presumably the id handed to the next `PendingPrediction` — confirm.
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Optional channel for [`DebugEvent`]s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    // NOTE(review): presumably the entity and time of the last refresh, used for throttling — confirm.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    /// Store of related excerpts used as prediction context.
    context: Entity<RelatedExcerptStore>,
    /// Per-worktree license watchers; `LastEvent::finalize` consults them to
    /// decide whether an event comes from an open-source repo.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history of recent user actions (`USER_ACTION_HISTORY_SIZE` entries).
    user_actions: VecDeque<UserActionRecord>,
    /// Subscriptions owned by this state.
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot instance started by `start_copilot_for_project`, if any.
    copilot: Option<Entity<Copilot>>,
}
 320
 321impl ProjectState {
 322    fn record_user_action(&mut self, action: UserActionRecord) {
 323        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 324            self.user_actions.pop_front();
 325        }
 326        self.user_actions.push_back(action);
 327    }
 328
 329    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 330        self.events
 331            .iter()
 332            .cloned()
 333            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 334                let (one, two) = event.split_by_pause();
 335                let one = one.finalize(&self.license_detection_watchers, cx);
 336                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 337                one.into_iter().chain(two)
 338            }))
 339            .collect()
 340    }
 341
 342    fn cancel_pending_prediction(
 343        &mut self,
 344        pending_prediction: PendingPrediction,
 345        cx: &mut Context<EditPredictionStore>,
 346    ) {
 347        self.cancelled_predictions.insert(pending_prediction.id);
 348
 349        if pending_prediction.drop_on_cancel {
 350            drop(pending_prediction.task);
 351        } else {
 352            cx.spawn(async move |this, cx| {
 353                let Some(prediction_id) = pending_prediction.task.await else {
 354                    return;
 355                };
 356
 357                this.update(cx, |this, cx| {
 358                    this.reject_prediction(
 359                        prediction_id,
 360                        EditPredictionRejectReason::Canceled,
 361                        false,
 362                        cx,
 363                    );
 364                })
 365                .ok();
 366            })
 367            .detach()
 368        }
 369    }
 370
 371    fn active_buffer(
 372        &self,
 373        project: &Entity<Project>,
 374        cx: &App,
 375    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 376        let project = project.read(cx);
 377        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 378        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 379        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 380        Some((active_buffer, registered_buffer.last_position))
 381    }
 382}
 383
/// The prediction currently held for a project, together with what requested
/// it and how it has been displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    /// The prediction itself.
    pub prediction: EditPrediction,
    /// Whether the prediction has been shown (presumably to the user — confirm at the write site).
    pub was_shown: bool,
    /// How the prediction was displayed, if it was.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
 391
 392impl CurrentEditPrediction {
 393    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 394        let Some(new_edits) = self
 395            .prediction
 396            .interpolate(&self.prediction.buffer.read(cx))
 397        else {
 398            return false;
 399        };
 400
 401        if self.prediction.buffer != old_prediction.prediction.buffer {
 402            return true;
 403        }
 404
 405        let Some(old_edits) = old_prediction
 406            .prediction
 407            .interpolate(&old_prediction.prediction.buffer.read(cx))
 408        else {
 409            return true;
 410        };
 411
 412        let requested_by_buffer_id = self.requested_by.buffer_id();
 413
 414        // This reduces the occurrence of UI thrash from replacing edits
 415        //
 416        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 417        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 418            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 419            && old_edits.len() == 1
 420            && new_edits.len() == 1
 421        {
 422            let (old_range, old_text) = &old_edits[0];
 423            let (new_range, new_text) = &new_edits[0];
 424            new_range == old_range && new_text.starts_with(old_text.as_ref())
 425        } else {
 426            true
 427        }
 428    }
 429}
 430
 431#[derive(Debug, Clone)]
 432enum PredictionRequestedBy {
 433    DiagnosticsUpdate,
 434    Buffer(EntityId),
 435}
 436
 437impl PredictionRequestedBy {
 438    pub fn buffer_id(&self) -> Option<EntityId> {
 439        match self {
 440            PredictionRequestedBy::DiagnosticsUpdate => None,
 441            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 442        }
 443    }
 444}
 445
/// An in-flight prediction request tracked by `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    /// Task resolving to the produced prediction's id, if one was produced.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 454
 455/// A prediction from the perspective of a buffer.
 456#[derive(Debug)]
 457enum BufferEditPrediction<'a> {
 458    Local { prediction: &'a EditPrediction },
 459    Jump { prediction: &'a EditPrediction },
 460}
 461
 462#[cfg(test)]
 463impl std::ops::Deref for BufferEditPrediction<'_> {
 464    type Target = EditPrediction;
 465
 466    fn deref(&self) -> &Self::Target {
 467        match self {
 468            BufferEditPrediction::Local { prediction } => prediction,
 469            BufferEditPrediction::Jump { prediction } => prediction,
 470        }
 471    }
 472}
 473
/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    /// The buffer's file, if it has one.
    file: Option<Arc<dyn File>>,
    // NOTE(review): presumably the snapshot as of the last observed change — confirm in `register_buffer_impl`.
    snapshot: TextBufferSnapshot,
    /// Last recorded position in the buffer, if any (returned by `ProjectState::active_buffer`).
    last_position: Option<Anchor>,
    /// Subscriptions owned by this registration.
    _subscriptions: [gpui::Subscription; 2],
}
 480
/// The most recent, still-open edit event for a buffer.
#[derive(Clone)]
struct LastEvent {
    /// Snapshot from before the event's edits (diff base in `finalize`).
    old_snapshot: TextBufferSnapshot,
    /// Snapshot from after the event's edits.
    new_snapshot: TextBufferSnapshot,
    /// File before the event; its path becomes `old_path` in `finalize`.
    old_file: Option<Arc<dyn File>>,
    /// File after the event; its path becomes `path` in `finalize`.
    new_file: Option<Arc<dyn File>>,
    /// Edited range, if known. `split_by_pause` clears it and `finalize` does not read it.
    edit_range: Option<Range<Anchor>>,
    /// Copied into `zeta_prompt::Event::BufferChange::predicted` by `finalize`.
    predicted: bool,
    /// Snapshot taken at the last editing pause; `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    /// Time of the last edit in this event, if recorded.
    last_edit_time: Option<Instant>,
}
 492
 493impl LastEvent {
 494    pub fn finalize(
 495        &self,
 496        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 497        cx: &App,
 498    ) -> Option<StoredEvent> {
 499        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 500        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 501
 502        let in_open_source_repo =
 503            [self.new_file.as_ref(), self.old_file.as_ref()]
 504                .iter()
 505                .all(|file| {
 506                    file.is_some_and(|file| {
 507                        license_detection_watchers
 508                            .get(&file.worktree_id(cx))
 509                            .is_some_and(|watcher| watcher.is_project_open_source())
 510                    })
 511                });
 512
 513        let (diff, edit_range) =
 514            compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 515
 516        if path == old_path && diff.is_empty() {
 517            None
 518        } else {
 519            Some(StoredEvent {
 520                event: Arc::new(zeta_prompt::Event::BufferChange {
 521                    old_path,
 522                    path,
 523                    diff,
 524                    in_open_source_repo,
 525                    predicted: self.predicted,
 526                }),
 527                edit_range: self.new_snapshot.anchor_before(edit_range.start)
 528                    ..self.new_snapshot.anchor_before(edit_range.end),
 529                old_snapshot: self.old_snapshot.clone(),
 530            })
 531        }
 532    }
 533
 534    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 535        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 536            return (self.clone(), None);
 537        };
 538
 539        let before = LastEvent {
 540            old_snapshot: self.old_snapshot.clone(),
 541            new_snapshot: boundary_snapshot.clone(),
 542            old_file: self.old_file.clone(),
 543            new_file: self.new_file.clone(),
 544            edit_range: None,
 545            predicted: self.predicted,
 546            snapshot_after_last_editing_pause: None,
 547            last_edit_time: self.last_edit_time,
 548        };
 549
 550        let after = LastEvent {
 551            old_snapshot: boundary_snapshot.clone(),
 552            new_snapshot: self.new_snapshot.clone(),
 553            old_file: self.old_file.clone(),
 554            new_file: self.new_file.clone(),
 555            edit_range: None,
 556            predicted: self.predicted,
 557            snapshot_after_last_editing_pause: None,
 558            last_edit_time: self.last_edit_time,
 559        };
 560
 561        (before, Some(after))
 562    }
 563}
 564
 565pub(crate) fn compute_diff_between_snapshots(
 566    old_snapshot: &TextBufferSnapshot,
 567    new_snapshot: &TextBufferSnapshot,
 568) -> Option<(String, Range<Point>)> {
 569    let edits: Vec<Edit<usize>> = new_snapshot
 570        .edits_since::<usize>(&old_snapshot.version)
 571        .collect();
 572
 573    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 574
 575    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 576    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 577    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 578    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 579
 580    const CONTEXT_LINES: u32 = 3;
 581
 582    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 583    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 584    let old_context_end_row =
 585        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 586    let new_context_end_row =
 587        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 588
 589    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 590    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 591    let old_end_line_offset = old_snapshot
 592        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 593    let new_end_line_offset = new_snapshot
 594        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 595    let old_edit_range = old_start_line_offset..old_end_line_offset;
 596    let new_edit_range = new_start_line_offset..new_end_line_offset;
 597
 598    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 599    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 600
 601    let diff = language::unified_diff_with_offsets(
 602        &old_region_text,
 603        &new_region_text,
 604        old_context_start_row,
 605        new_context_start_row,
 606    );
 607
 608    Some((diff, new_start_point..new_end_point))
 609}
 610
 611fn buffer_path_with_id_fallback(
 612    file: Option<&Arc<dyn File>>,
 613    snapshot: &TextBufferSnapshot,
 614    cx: &App,
 615) -> Arc<Path> {
 616    if let Some(file) = file {
 617        file.full_path(cx).into()
 618    } else {
 619        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 620    }
 621}
 622
 623impl EditPredictionStore {
 624    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 625        cx.try_global::<EditPredictionStoreGlobal>()
 626            .map(|global| global.0.clone())
 627    }
 628
 629    pub fn global(
 630        client: &Arc<Client>,
 631        user_store: &Entity<UserStore>,
 632        cx: &mut App,
 633    ) -> Entity<Self> {
 634        cx.try_global::<EditPredictionStoreGlobal>()
 635            .map(|global| global.0.clone())
 636            .unwrap_or_else(|| {
 637                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 638                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 639                ep_store
 640            })
 641    }
 642
 643    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 644        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 645        let data_collection_choice = Self::load_data_collection_choice();
 646
 647        let llm_token = LlmApiToken::default();
 648
 649        let (reject_tx, reject_rx) = mpsc::unbounded();
 650        cx.background_spawn({
 651            let client = client.clone();
 652            let llm_token = llm_token.clone();
 653            let app_version = AppVersion::global(cx);
 654            let background_executor = cx.background_executor().clone();
 655            async move {
 656                Self::handle_rejected_predictions(
 657                    reject_rx,
 658                    client,
 659                    llm_token,
 660                    app_version,
 661                    background_executor,
 662                )
 663                .await
 664            }
 665        })
 666        .detach();
 667
 668        let this = Self {
 669            projects: HashMap::default(),
 670            client,
 671            user_store,
 672            llm_token,
 673            _llm_token_subscription: cx.subscribe(
 674                &refresh_llm_token_listener,
 675                |this, _listener, _event, cx| {
 676                    let client = this.client.clone();
 677                    let llm_token = this.llm_token.clone();
 678                    cx.spawn(async move |_this, _cx| {
 679                        llm_token.refresh(&client).await?;
 680                        anyhow::Ok(())
 681                    })
 682                    .detach_and_log_err(cx);
 683                },
 684            ),
 685            update_required: false,
 686            edit_prediction_model: EditPredictionModel::Zeta2,
 687            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 688            sweep_ai: SweepAi::new(cx),
 689            mercury: Mercury::new(cx),
 690            ollama: Ollama::new(),
 691
 692            data_collection_choice,
 693            reject_predictions_tx: reject_tx,
 694            rated_predictions: Default::default(),
 695            shown_predictions: Default::default(),
 696        };
 697
 698        this
 699    }
 700
 701    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 702        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 703        let format = ZetaFormat::parse(&version_str).ok()?;
 704        let model_id = env::var("ZED_ZETA_MODEL").ok();
 705        Some(Zeta2RawConfig { model_id, format })
 706    }
 707
    /// Selects which backend produces edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Overrides the raw Zeta2 configuration (otherwise read from the environment in `new`).
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }

    /// Returns the raw Zeta2 configuration, if one is set.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }
 719
 720    pub fn icons(&self) -> edit_prediction_types::EditPredictionIconSet {
 721        use ui::IconName;
 722        match self.edit_prediction_model {
 723            EditPredictionModel::Sweep => {
 724                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 725                    .with_disabled(IconName::SweepAiDisabled)
 726                    .with_up(IconName::SweepAiUp)
 727                    .with_down(IconName::SweepAiDown)
 728                    .with_error(IconName::SweepAiError)
 729            }
 730            EditPredictionModel::Mercury => {
 731                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 732            }
 733            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
 734                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 735                    .with_disabled(IconName::ZedPredictDisabled)
 736                    .with_up(IconName::ZedPredictUp)
 737                    .with_down(IconName::ZedPredictDown)
 738                    .with_error(IconName::ZedPredictError)
 739            }
 740            EditPredictionModel::Ollama => {
 741                edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 742            }
 743        }
 744    }
 745
    /// Returns whether an API key is available for the Sweep backend.
    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
        self.sweep_ai.api_token.read(cx).has_key()
    }

    /// Returns whether an API key is available for the Mercury backend.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }
 753
 754    pub fn clear_history(&mut self) {
 755        for project_state in self.projects.values_mut() {
 756            project_state.events.clear();
 757            project_state.last_event.take();
 758        }
 759    }
 760
 761    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 762        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 763            project_state.events.clear();
 764            project_state.last_event.take();
 765        }
 766    }
 767
 768    pub fn edit_history_for_project(
 769        &self,
 770        project: &Entity<Project>,
 771        cx: &App,
 772    ) -> Vec<StoredEvent> {
 773        self.projects
 774            .get(&project.entity_id())
 775            .map(|project_state| project_state.events(cx))
 776            .unwrap_or_default()
 777    }
 778
 779    pub fn context_for_project<'a>(
 780        &'a self,
 781        project: &Entity<Project>,
 782        cx: &'a mut App,
 783    ) -> Vec<RelatedFile> {
 784        self.projects
 785            .get(&project.entity_id())
 786            .map(|project_state| {
 787                project_state.context.update(cx, |context, cx| {
 788                    context
 789                        .related_files_with_buffers(cx)
 790                        .map(|(mut related_file, buffer)| {
 791                            related_file.in_open_source_repo = buffer
 792                                .read(cx)
 793                                .file()
 794                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
 795                            related_file
 796                        })
 797                        .collect()
 798                })
 799            })
 800            .unwrap_or_default()
 801    }
 802
 803    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 804        self.projects
 805            .get(&project.entity_id())
 806            .and_then(|project| project.copilot.clone())
 807    }
 808
 809    pub fn start_copilot_for_project(
 810        &mut self,
 811        project: &Entity<Project>,
 812        cx: &mut Context<Self>,
 813    ) -> Option<Entity<Copilot>> {
 814        if DisableAiSettings::get(None, cx).disable_ai {
 815            return None;
 816        }
 817        let state = self.get_or_init_project(project, cx);
 818
 819        if state.copilot.is_some() {
 820            return state.copilot.clone();
 821        }
 822        let _project = project.clone();
 823        let project = project.read(cx);
 824
 825        let node = project.node_runtime().cloned();
 826        if let Some(node) = node {
 827            let next_id = project.languages().next_language_server_id();
 828            let fs = project.fs().clone();
 829
 830            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 831            state.copilot = Some(copilot.clone());
 832            Some(copilot)
 833        } else {
 834            None
 835        }
 836    }
 837
 838    pub fn context_for_project_with_buffers<'a>(
 839        &'a self,
 840        project: &Entity<Project>,
 841        cx: &'a mut App,
 842    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 843        self.projects
 844            .get(&project.entity_id())
 845            .map(|project| {
 846                project.context.update(cx, |context, cx| {
 847                    context.related_files_with_buffers(cx).collect()
 848                })
 849            })
 850            .unwrap_or_default()
 851    }
 852
 853    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 854        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta2) {
 855            self.user_store.read(cx).edit_prediction_usage()
 856        } else {
 857            None
 858        }
 859    }
 860
 861    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 862        self.get_or_init_project(project, cx);
 863    }
 864
 865    pub fn register_buffer(
 866        &mut self,
 867        buffer: &Entity<Buffer>,
 868        project: &Entity<Project>,
 869        cx: &mut Context<Self>,
 870    ) {
 871        let project_state = self.get_or_init_project(project, cx);
 872        Self::register_buffer_impl(project_state, buffer, project, cx);
 873    }
 874
 875    fn get_or_init_project(
 876        &mut self,
 877        project: &Entity<Project>,
 878        cx: &mut Context<Self>,
 879    ) -> &mut ProjectState {
 880        let entity_id = project.entity_id();
 881        self.projects
 882            .entry(entity_id)
 883            .or_insert_with(|| ProjectState {
 884                context: {
 885                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 886                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 887                        this.handle_excerpt_store_event(entity_id, event);
 888                    })
 889                    .detach();
 890                    related_excerpt_store
 891                },
 892                events: VecDeque::new(),
 893                last_event: None,
 894                recent_paths: VecDeque::new(),
 895                debug_tx: None,
 896                registered_buffers: HashMap::default(),
 897                current_prediction: None,
 898                cancelled_predictions: HashSet::default(),
 899                pending_predictions: ArrayVec::new(),
 900                next_pending_prediction_id: 0,
 901                last_prediction_refresh: None,
 902                license_detection_watchers: HashMap::default(),
 903                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
 904                _subscriptions: [
 905                    cx.subscribe(&project, Self::handle_project_event),
 906                    cx.observe_release(&project, move |this, _, cx| {
 907                        this.projects.remove(&entity_id);
 908                        cx.notify();
 909                    }),
 910                ],
 911                copilot: None,
 912            })
 913    }
 914
 915    pub fn remove_project(&mut self, project: &Entity<Project>) {
 916        self.projects.remove(&project.entity_id());
 917    }
 918
 919    fn handle_excerpt_store_event(
 920        &mut self,
 921        project_entity_id: EntityId,
 922        event: &RelatedExcerptStoreEvent,
 923    ) {
 924        if let Some(project_state) = self.projects.get(&project_entity_id) {
 925            if let Some(debug_tx) = project_state.debug_tx.clone() {
 926                match event {
 927                    RelatedExcerptStoreEvent::StartedRefresh => {
 928                        debug_tx
 929                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 930                                ContextRetrievalStartedDebugEvent {
 931                                    project_entity_id: project_entity_id,
 932                                    timestamp: Instant::now(),
 933                                    search_prompt: String::new(),
 934                                },
 935                            ))
 936                            .ok();
 937                    }
 938                    RelatedExcerptStoreEvent::FinishedRefresh {
 939                        cache_hit_count,
 940                        cache_miss_count,
 941                        mean_definition_latency,
 942                        max_definition_latency,
 943                    } => {
 944                        debug_tx
 945                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 946                                ContextRetrievalFinishedDebugEvent {
 947                                    project_entity_id: project_entity_id,
 948                                    timestamp: Instant::now(),
 949                                    metadata: vec![
 950                                        (
 951                                            "Cache Hits",
 952                                            format!(
 953                                                "{}/{}",
 954                                                cache_hit_count,
 955                                                cache_hit_count + cache_miss_count
 956                                            )
 957                                            .into(),
 958                                        ),
 959                                        (
 960                                            "Max LSP Time",
 961                                            format!("{} ms", max_definition_latency.as_millis())
 962                                                .into(),
 963                                        ),
 964                                        (
 965                                            "Mean LSP Time",
 966                                            format!("{} ms", mean_definition_latency.as_millis())
 967                                                .into(),
 968                                        ),
 969                                    ],
 970                                },
 971                            ))
 972                            .ok();
 973                    }
 974                }
 975            }
 976        }
 977    }
 978
 979    pub fn debug_info(
 980        &mut self,
 981        project: &Entity<Project>,
 982        cx: &mut Context<Self>,
 983    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 984        let project_state = self.get_or_init_project(project, cx);
 985        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 986        project_state.debug_tx = Some(debug_watch_tx);
 987        debug_watch_rx
 988    }
 989
 990    fn handle_project_event(
 991        &mut self,
 992        project: Entity<Project>,
 993        event: &project::Event,
 994        cx: &mut Context<Self>,
 995    ) {
 996        // TODO [zeta2] init with recent paths
 997        match event {
 998            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 999                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1000                    return;
1001                };
1002                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1003                if let Some(path) = path {
1004                    if let Some(ix) = project_state
1005                        .recent_paths
1006                        .iter()
1007                        .position(|probe| probe == &path)
1008                    {
1009                        project_state.recent_paths.remove(ix);
1010                    }
1011                    project_state.recent_paths.push_front(path);
1012                }
1013            }
1014            project::Event::DiagnosticsUpdated { .. } => {
1015                if cx.has_flag::<Zeta2FeatureFlag>() {
1016                    self.refresh_prediction_from_diagnostics(project, cx);
1017                }
1018            }
1019            _ => (),
1020        }
1021    }
1022
    /// Ensures `buffer` is tracked in `project_state` and returns its
    /// `RegisteredBuffer` entry.
    ///
    /// If the buffer has a file whose worktree is found in `project`, a
    /// `LicenseDetectionWatcher` is created for that worktree the first time the
    /// worktree is seen. The watcher entry is removed again when the worktree
    /// entity is released. This step runs on every call, including for buffers
    /// that are already registered.
    ///
    /// On first registration, the buffer's current text snapshot and file are
    /// recorded together with two subscriptions:
    /// - `BufferEvent::Edited` is forwarded to `report_changes_for_buffer`, but
    ///   only while the (weakly held) project can still be upgraded.
    /// - Releasing the buffer removes its entry from the project state.
    ///
    /// An existing entry is returned unchanged.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily attach one license watcher per worktree that hosts a registered buffer.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher once its worktree entity is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                // Baseline snapshot/file that later edits are diffed against.
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Report user edits. The project is captured as a weak
                        // handle (`downgrade`), and edits are ignored once it
                        // can no longer be upgraded.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        // Stop tracking the buffer once its entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1088
    /// Records the edits made to `buffer` since it was last observed.
    ///
    /// The buffer is registered first if needed. Its current snapshot is then
    /// diffed against the one stored in its `RegisteredBuffer`, and:
    /// - a `UserActionRecord` is appended that classifies the change and points
    ///   at the end of the last edit;
    /// - the change is either coalesced into the project's pending `last_event`
    ///   or that pending event is finalized into `events` and a new one started.
    ///   Coalescing requires the same buffer, consecutive snapshots, the same
    ///   `is_predicted` value, and the new edit being within
    ///   `CHANGE_GROUPING_LINE_SPAN` lines of the previous one.
    ///
    /// `is_predicted` is `true` when called from `accept_current_prediction` and
    /// `false` for ordinary `BufferEvent::Edited` notifications.
    ///
    /// Returns early when the buffer version is unchanged or no edits are found.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing to report if the buffer hasn't changed since the last observation.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        // Advance the stored baseline, keeping the previous state for diffing.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        // Summarize every edit since the old snapshot: edit count, total deleted
        // and inserted lengths, an anchor range from the first edit's start to
        // the last edit's end, and the end offset of the last edit in the new
        // snapshot.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        // The version moved but no edits were produced: nothing to record.
        let Some(edit_range) = edit_range else {
            return;
        };

        // Classify the change. When the total inserted (or deleted) length equals
        // the number of edits, i.e. one offset unit per edit, it counts as a
        // "char" action; otherwise it is a "selection" action. Pure deletions are
        // checked separately from inserts and replacements.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // Wall-clock milliseconds; falls back to 0 if the clock reads before the Unix epoch.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            // Extend the pending event only if this change directly follows it,
            // comes from the same source, and is close to its last edit.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // If at least LAST_CHANGE_GROUPING_TIME elapsed since the previous
                // edit, remember the snapshot from just before this edit.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Close out the pending event. `finalize` may yield nothing, in which case
        // the event is dropped. The pop keeps `events` from growing past
        // EVENT_COUNT_MAX - 1 entries here (presumably leaving room for `last_event`).
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        // Start a new pending event covering this change.
        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1216
1217    fn prediction_at(
1218        &mut self,
1219        buffer: &Entity<Buffer>,
1220        position: Option<language::Anchor>,
1221        project: &Entity<Project>,
1222        cx: &App,
1223    ) -> Option<BufferEditPrediction<'_>> {
1224        let project_state = self.projects.get_mut(&project.entity_id())?;
1225        if let Some(position) = position
1226            && let Some(buffer) = project_state
1227                .registered_buffers
1228                .get_mut(&buffer.entity_id())
1229        {
1230            buffer.last_position = Some(position);
1231        }
1232
1233        let CurrentEditPrediction {
1234            requested_by,
1235            prediction,
1236            ..
1237        } = project_state.current_prediction.as_ref()?;
1238
1239        if prediction.targets_buffer(buffer.read(cx)) {
1240            Some(BufferEditPrediction::Local { prediction })
1241        } else {
1242            let show_jump = match requested_by {
1243                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1244                    requested_by_buffer_id == &buffer.entity_id()
1245                }
1246                PredictionRequestedBy::DiagnosticsUpdate => true,
1247            };
1248
1249            if show_jump {
1250                Some(BufferEditPrediction::Jump { prediction })
1251            } else {
1252                None
1253            }
1254        }
1255    }
1256
    /// Accepts the project's current prediction.
    ///
    /// The steps are:
    /// 1. Take the current prediction out of the project state.
    /// 2. Call `report_changes_for_buffer` on the prediction's buffer with
    ///    `is_predicted = true`.
    /// 3. Cancel every pending prediction request.
    /// 4. Notify the active model's acceptance hook. Ollama has none; Zeta1 and
    ///    Zeta2 share one.
    ///
    /// Does nothing if the project is unknown or has no current prediction.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(current_prediction) = self
            .projects
            .get_mut(&project.entity_id())
            .and_then(|project_state| project_state.current_prediction.take())
        else {
            return;
        };

        self.report_changes_for_buffer(&current_prediction.prediction.buffer, project, true, cx);

        // Can't hold a `&mut project_state` borrow across the call to
        // `report_changes_for_buffer`, so look the state up again.
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        match self.edit_prediction_model {
            EditPredictionModel::Sweep => {
                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
            }
            EditPredictionModel::Mercury => {
                mercury::edit_prediction_accepted(
                    current_prediction.prediction.id,
                    self.client.http_client(),
                    cx,
                );
            }
            EditPredictionModel::Ollama => {}
            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
                zeta2::edit_prediction_accepted(self, current_prediction, cx)
            }
        }
    }
1294
1295    async fn handle_rejected_predictions(
1296        rx: UnboundedReceiver<EditPredictionRejection>,
1297        client: Arc<Client>,
1298        llm_token: LlmApiToken,
1299        app_version: Version,
1300        background_executor: BackgroundExecutor,
1301    ) {
1302        let mut rx = std::pin::pin!(rx.peekable());
1303        let mut batched = Vec::new();
1304
1305        while let Some(rejection) = rx.next().await {
1306            batched.push(rejection);
1307
1308            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1309                select_biased! {
1310                    next = rx.as_mut().peek().fuse() => {
1311                        if next.is_some() {
1312                            continue;
1313                        }
1314                    }
1315                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1316                }
1317            }
1318
1319            let url = client
1320                .http_client()
1321                .build_zed_llm_url("/predict_edits/reject", &[])
1322                .unwrap();
1323
1324            let flush_count = batched
1325                .len()
1326                // in case items have accumulated after failure
1327                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1328            let start = batched.len() - flush_count;
1329
1330            let body = RejectEditPredictionsBodyRef {
1331                rejections: &batched[start..],
1332            };
1333
1334            let result = Self::send_api_request::<()>(
1335                |builder| {
1336                    let req = builder
1337                        .uri(url.as_ref())
1338                        .body(serde_json::to_string(&body)?.into());
1339                    anyhow::Ok(req?)
1340                },
1341                client.clone(),
1342                llm_token.clone(),
1343                app_version.clone(),
1344                true,
1345            )
1346            .await;
1347
1348            if result.log_err().is_some() {
1349                batched.drain(start..);
1350            }
1351        }
1352    }
1353
1354    fn reject_current_prediction(
1355        &mut self,
1356        reason: EditPredictionRejectReason,
1357        project: &Entity<Project>,
1358        cx: &App,
1359    ) {
1360        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1361            project_state.pending_predictions.clear();
1362            if let Some(prediction) = project_state.current_prediction.take() {
1363                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
1364            }
1365        };
1366    }
1367
1368    fn did_show_current_prediction(
1369        &mut self,
1370        project: &Entity<Project>,
1371        display_type: edit_prediction_types::SuggestionDisplayType,
1372        cx: &mut Context<Self>,
1373    ) {
1374        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1375            return;
1376        };
1377
1378        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1379            return;
1380        };
1381
1382        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1383        let previous_shown_with = current_prediction.shown_with;
1384
1385        if previous_shown_with.is_none() || !is_jump {
1386            current_prediction.shown_with = Some(display_type);
1387        }
1388
1389        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1390
1391        if is_first_non_jump_show {
1392            current_prediction.was_shown = true;
1393        }
1394
1395        let display_type_changed = previous_shown_with != Some(display_type);
1396
1397        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1398            sweep_ai::edit_prediction_shown(
1399                &self.sweep_ai,
1400                self.client.clone(),
1401                &current_prediction.prediction,
1402                display_type,
1403                cx,
1404            );
1405        }
1406
1407        if is_first_non_jump_show {
1408            self.shown_predictions
1409                .push_front(current_prediction.prediction.clone());
1410            if self.shown_predictions.len() > 50 {
1411                let completion = self.shown_predictions.pop_back().unwrap();
1412                self.rated_predictions.remove(&completion.id);
1413            }
1414        }
1415    }
1416
1417    fn reject_prediction(
1418        &mut self,
1419        prediction_id: EditPredictionId,
1420        reason: EditPredictionRejectReason,
1421        was_shown: bool,
1422        cx: &App,
1423    ) {
1424        match self.edit_prediction_model {
1425            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1426                self.reject_predictions_tx
1427                    .unbounded_send(EditPredictionRejection {
1428                        request_id: prediction_id.to_string(),
1429                        reason,
1430                        was_shown,
1431                    })
1432                    .log_err();
1433            }
1434            EditPredictionModel::Sweep | EditPredictionModel::Ollama => {}
1435            EditPredictionModel::Mercury => {
1436                mercury::edit_prediction_rejected(
1437                    prediction_id,
1438                    was_shown,
1439                    reason,
1440                    self.client.http_client(),
1441                    cx,
1442                );
1443            }
1444        }
1445    }
1446
1447    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1448        self.projects
1449            .get(&project.entity_id())
1450            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1451    }
1452
1453    pub fn refresh_prediction_from_buffer(
1454        &mut self,
1455        project: Entity<Project>,
1456        buffer: Entity<Buffer>,
1457        position: language::Anchor,
1458        cx: &mut Context<Self>,
1459    ) {
1460        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1461            let Some(request_task) = this
1462                .update(cx, |this, cx| {
1463                    this.request_prediction(
1464                        &project,
1465                        &buffer,
1466                        position,
1467                        PredictEditsRequestTrigger::Other,
1468                        cx,
1469                    )
1470                })
1471                .log_err()
1472            else {
1473                return Task::ready(anyhow::Ok(None));
1474            };
1475
1476            cx.spawn(async move |_cx| {
1477                request_task.await.map(|prediction_result| {
1478                    prediction_result.map(|prediction_result| {
1479                        (
1480                            prediction_result,
1481                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1482                        )
1483                    })
1484                })
1485            })
1486        })
1487    }
1488
    /// Queues a prediction refresh driven by a diagnostics update rather than by
    /// an edit.
    ///
    /// Skipped when the project is unknown or already has a current prediction.
    /// Otherwise it is queued via `queue_prediction_refresh`, with the project's
    /// own entity id as the throttle key. The refresh then:
    /// 1. Reads the project's active buffer and cursor, stopping if predictions
    ///    are disabled there (`predictions_enabled_at`).
    /// 2. Asks `next_diagnostic_location` for a target buffer and position.
    /// 3. Requests a prediction at that target.
    ///
    /// If a current prediction appeared while the request was in flight, the
    /// existing one is kept and the new result is returned as a
    /// `CurrentPreferred` rejection.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Don't replace an existing prediction with a diagnostics-driven one.
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Capture the active buffer, its snapshot and the cursor point up front;
            // any failure (including a dropped store) yields no prediction.
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    // Without a known cursor position, fall back to the default point.
                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // Re-check after the await: keep an existing current prediction and
                // report this one as rejected in its favor.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1575
1576    fn predictions_enabled_at(
1577        snapshot: &BufferSnapshot,
1578        position: Option<language::Anchor>,
1579        cx: &App,
1580    ) -> bool {
1581        let file = snapshot.file();
1582        let all_settings = all_language_settings(file, cx);
1583        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1584            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1585        {
1586            return false;
1587        }
1588
1589        if let Some(last_position) = position {
1590            let settings = snapshot.settings_at(last_position, cx);
1591
1592            if !settings.edit_predictions_disabled_in.is_empty()
1593                && let Some(scope) = snapshot.language_scope_at(last_position)
1594                && let Some(scope_name) = scope.override_name()
1595                && settings
1596                    .edit_predictions_disabled_in
1597                    .iter()
1598                    .any(|s| s == scope_name)
1599            {
1600                return false;
1601            }
1602        }
1603
1604        true
1605    }
1606
    /// Minimum delay between two consecutive prediction refreshes for the same throttle
    /// entity. It is zero under `cfg(test)` so tests never wait on the timer.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;

    /// Queues a prediction refresh for `project`; the actual request is made by `do_refresh`.
    ///
    /// - Throttling: if the last refresh that actually started a request was for the same
    ///   `throttle_entity`, the task first waits until `THROTTLE_TIMEOUT` has elapsed since
    ///   that refresh.
    /// - Cancellation: if this pending prediction's id is found in `cancelled_predictions`
    ///   after the throttle wait, no request is made and the task resolves to `None`.
    /// - Completion: if not cancelled, an `Ok` prediction either replaces the current one
    ///   (which is rejected as `Replaced`) or is itself rejected as `CurrentPreferred`. An
    ///   `Err` prediction is rejected with its reason. Afterwards this pending prediction is
    ///   removed from `pending_predictions`, and every pending prediction queued before it is
    ///   cancelled.
    /// - Capacity: at most 1 pending prediction is kept for Ollama and 2 for other models.
    ///   When the queue is full, the most recently queued pending prediction is popped and
    ///   cancelled to make room for this one.
    ///
    /// The spawned task resolves to the id of the `EditPredictionResult` returned by
    /// `do_refresh`, even if that result was then rejected or discarded. It resolves to
    /// `None` if the refresh was skipped, errored (the error is logged), or produced nothing.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        let is_ollama = self.edit_prediction_model == EditPredictionModel::Ollama;
        // NOTE(review): how `drop_on_cancel` is honored lives in `cancel_pending_prediction`
        // (not visible here); presumably it drops the task so an in-flight request is aborted.
        let drop_on_cancel = is_ollama;
        let max_pending_predictions = if is_ollama { 1 } else { 2 };
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        // Snapshot of the last refresh at enqueue time; used only for the throttle decision.
        let last_request = project_state.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            // Starts as `true` so that a failed `update` (store already released) also skips
            // the request.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    // Record when this request actually starts; later refreshes throttle
                    // against this timestamp.
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                // Consumes the cancellation marker, if any, set while the request was running.
                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                                shown_with: None,
                            };

                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                        cx,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                        cx,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false, cx);
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-fetch the project state: the `reject_*` calls above needed `this` mutably.
                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Move the list out so entries can be cancelled through `project_state`
                // while iterating; it is put back below.
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        // Entries before `ix` were queued earlier and are now stale.
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        if project_state.pending_predictions.len() < max_pending_predictions {
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
        } else {
            // Queue is full: replace the most recently queued pending prediction (the last
            // element) with this one, and cancel the replaced one.
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1759
1760    pub fn request_prediction(
1761        &mut self,
1762        project: &Entity<Project>,
1763        active_buffer: &Entity<Buffer>,
1764        position: language::Anchor,
1765        trigger: PredictEditsRequestTrigger,
1766        cx: &mut Context<Self>,
1767    ) -> Task<Result<Option<EditPredictionResult>>> {
1768        self.request_prediction_internal(
1769            project.clone(),
1770            active_buffer.clone(),
1771            position,
1772            trigger,
1773            cx.has_flag::<Zeta2FeatureFlag>(),
1774            cx,
1775        )
1776    }
1777
1778    fn request_prediction_internal(
1779        &mut self,
1780        project: Entity<Project>,
1781        active_buffer: Entity<Buffer>,
1782        position: language::Anchor,
1783        trigger: PredictEditsRequestTrigger,
1784        allow_jump: bool,
1785        cx: &mut Context<Self>,
1786    ) -> Task<Result<Option<EditPredictionResult>>> {
1787        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1788
1789        self.get_or_init_project(&project, cx);
1790        let project_state = self.projects.get(&project.entity_id()).unwrap();
1791        let stored_events = project_state.events(cx);
1792        let has_events = !stored_events.is_empty();
1793        let events: Vec<Arc<zeta_prompt::Event>> =
1794            stored_events.into_iter().map(|e| e.event).collect();
1795        let debug_tx = project_state.debug_tx.clone();
1796
1797        let snapshot = active_buffer.read(cx).snapshot();
1798        let cursor_point = position.to_point(&snapshot);
1799        let current_offset = position.to_offset(&snapshot);
1800
1801        let mut user_actions: Vec<UserActionRecord> =
1802            project_state.user_actions.iter().cloned().collect();
1803
1804        if let Some(last_action) = user_actions.last() {
1805            if last_action.buffer_id == active_buffer.entity_id()
1806                && current_offset != last_action.offset
1807            {
1808                let timestamp_epoch_ms = SystemTime::now()
1809                    .duration_since(UNIX_EPOCH)
1810                    .map(|d| d.as_millis() as u64)
1811                    .unwrap_or(0);
1812                user_actions.push(UserActionRecord {
1813                    action_type: UserActionType::CursorMovement,
1814                    buffer_id: active_buffer.entity_id(),
1815                    line_number: cursor_point.row,
1816                    offset: current_offset,
1817                    timestamp_epoch_ms,
1818                });
1819            }
1820        }
1821        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1822        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1823        let diagnostic_search_range =
1824            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1825
1826        let related_files = self.context_for_project(&project, cx);
1827
1828        let inputs = EditPredictionModelInput {
1829            project: project.clone(),
1830            buffer: active_buffer.clone(),
1831            snapshot: snapshot.clone(),
1832            position,
1833            events,
1834            related_files,
1835            recent_paths: project_state.recent_paths.clone(),
1836            trigger,
1837            diagnostic_search_range: diagnostic_search_range.clone(),
1838            debug_tx,
1839            user_actions,
1840        };
1841
1842        let task = match &self.edit_prediction_model {
1843            EditPredictionModel::Zeta1 => zeta2::request_prediction_with_zeta2(
1844                self,
1845                inputs,
1846                Some(zeta_prompt::EditPredictionModelKind::Zeta1),
1847                cx,
1848            ),
1849            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
1850                self,
1851                inputs,
1852                Some(zeta_prompt::EditPredictionModelKind::Zeta2),
1853                cx,
1854            ),
1855            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1856            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1857            EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx),
1858        };
1859
1860        cx.spawn(async move |this, cx| {
1861            let prediction = task.await?;
1862
1863            if prediction.is_none() && allow_jump {
1864                let cursor_point = position.to_point(&snapshot);
1865                if has_events
1866                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1867                        active_buffer.clone(),
1868                        &snapshot,
1869                        diagnostic_search_range,
1870                        cursor_point,
1871                        &project,
1872                        cx,
1873                    )
1874                    .await?
1875                {
1876                    return this
1877                        .update(cx, |this, cx| {
1878                            this.request_prediction_internal(
1879                                project,
1880                                jump_buffer,
1881                                jump_position,
1882                                trigger,
1883                                false,
1884                                cx,
1885                            )
1886                        })?
1887                        .await;
1888                }
1889
1890                return anyhow::Ok(None);
1891            }
1892
1893            Ok(prediction)
1894        })
1895    }
1896
    /// Finds a diagnostic location to jump to when no prediction was produced at the cursor.
    ///
    /// 1. In `active_buffer`, takes the primary entry of each diagnostic group and skips
    ///    those whose range overlaps `active_buffer_diagnostic_search_range`. Of the rest, it
    ///    picks the one whose start row is closest to `active_buffer_cursor_point`.
    /// 2. If nothing is found, it looks at the project's paths that have diagnostic summaries,
    ///    excluding the active buffer's own path. It picks the path sharing the most leading
    ///    path components with the active buffer's path, opens that buffer, and chooses the
    ///    start of its diagnostic with the smallest `severity` value.
    ///
    /// Returns `Ok(None)` when neither step finds a location. Errors from opening the
    /// fallback buffer are propagated.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            // Distance is measured in rows only; columns are ignored.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Fallback: look for diagnostics in another file of the project.
        if jump_location.is_none() {
            // `None` when the active buffer has no backing file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // `None` when no other path has diagnostic summaries.
            // NOTE(review): the meaning of the `false` argument to `diagnostic_summaries` is
            // not visible here — confirm.
            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    // On ties, `max_by_key` keeps the last candidate.
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            });

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        // Picks the smallest severity value. This is presumably the most severe
                        // diagnostic if severities follow LSP ordering (ERROR = 1) — confirm.
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1975
1976    async fn send_raw_llm_request(
1977        request: RawCompletionRequest,
1978        client: Arc<Client>,
1979        custom_url: Option<Arc<Url>>,
1980        llm_token: LlmApiToken,
1981        app_version: Version,
1982    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1983        let url = if let Some(custom_url) = custom_url {
1984            custom_url.as_ref().clone()
1985        } else {
1986            client
1987                .http_client()
1988                .build_zed_llm_url("/predict_edits/raw", &[])?
1989        };
1990
1991        Self::send_api_request(
1992            |builder| {
1993                let req = builder
1994                    .uri(url.as_ref())
1995                    .body(serde_json::to_string(&request)?.into());
1996                Ok(req?)
1997            },
1998            client,
1999            llm_token,
2000            app_version,
2001            true,
2002        )
2003        .await
2004    }
2005
2006    pub(crate) async fn send_v3_request(
2007        input: ZetaPromptInput,
2008        client: Arc<Client>,
2009        llm_token: LlmApiToken,
2010        app_version: Version,
2011        trigger: PredictEditsRequestTrigger,
2012    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2013        let url = client
2014            .http_client()
2015            .build_zed_llm_url("/predict_edits/v3", &[])?;
2016
2017        let request = PredictEditsV3Request { input, trigger };
2018
2019        let json_bytes = serde_json::to_vec(&request)?;
2020        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2021
2022        Self::send_api_request(
2023            |builder| {
2024                let req = builder
2025                    .uri(url.as_ref())
2026                    .header("Content-Encoding", "zstd")
2027                    .body(compressed.clone().into());
2028                Ok(req?)
2029            },
2030            client,
2031            llm_token,
2032            app_version,
2033            true,
2034        )
2035        .await
2036    }
2037
2038    fn handle_api_response<T>(
2039        this: &WeakEntity<Self>,
2040        response: Result<(T, Option<EditPredictionUsage>)>,
2041        cx: &mut gpui::AsyncApp,
2042    ) -> Result<T> {
2043        match response {
2044            Ok((data, usage)) => {
2045                if let Some(usage) = usage {
2046                    this.update(cx, |this, cx| {
2047                        this.user_store.update(cx, |user_store, cx| {
2048                            user_store.update_edit_prediction_usage(usage, cx);
2049                        });
2050                    })
2051                    .ok();
2052                }
2053                Ok(data)
2054            }
2055            Err(err) => {
2056                if err.is::<ZedUpdateRequiredError>() {
2057                    cx.update(|cx| {
2058                        this.update(cx, |this, _cx| {
2059                            this.update_required = true;
2060                        })
2061                        .ok();
2062
2063                        let error_message: SharedString = err.to_string().into();
2064                        show_app_notification(
2065                            NotificationId::unique::<ZedUpdateRequiredError>(),
2066                            cx,
2067                            move |cx| {
2068                                cx.new(|cx| {
2069                                    ErrorMessagePrompt::new(error_message.clone(), cx)
2070                                        .with_link_button("Update Zed", "https://zed.dev/releases")
2071                                })
2072                            },
2073                        );
2074                    });
2075                }
2076                Err(err)
2077            }
2078        }
2079    }
2080
    /// POSTs a request built by `build` and deserializes a successful JSON response as `Res`.
    ///
    /// Authentication:
    /// - If `ZED_PREDICT_EDITS_TOKEN` is set in the environment, its value is used as the
    ///   bearer token.
    /// - Otherwise an LLM token is acquired. A failure to acquire it is an error only when
    ///   `require_auth` is true; without it the request is sent with no `Authorization` header.
    ///
    /// `build` is called once per attempt (hence `Fn`) with a builder that already carries the
    /// method, `Content-Type`, version and (optionally) `Authorization` headers.
    ///
    /// Returns the parsed body plus usage info from the response headers (`None` if the
    /// headers don't parse). Errors:
    /// - `ZedUpdateRequiredError` when the response's minimum-required-version header exceeds
    ///   `app_version`. This check runs before the status check, so it applies to any status.
    /// - A "Request failed" error carrying the status and body for other failed statuses.
    ///   When a token was sent and the response asks for an LLM token refresh, the token is
    ///   refreshed and the request retried exactly once first.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();

        let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
            Some(custom_token)
        } else if require_auth {
            Some(llm_token.acquire(&client).await?)
        } else {
            llm_token.acquire(&client).await.ok()
        };
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // An unparseable version header is ignored rather than treated as an error.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
                // Retry at most once with a refreshed token.
                // NOTE(review): this also replaces a custom `ZED_PREDICT_EDITS_TOKEN` value
                // with a refreshed LLM token — confirm that is intended.
                did_retry = true;
                token = Some(llm_token.refresh(&client).await?);
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2152
2153    pub fn refresh_context(
2154        &mut self,
2155        project: &Entity<Project>,
2156        buffer: &Entity<language::Buffer>,
2157        cursor_position: language::Anchor,
2158        cx: &mut Context<Self>,
2159    ) {
2160        self.get_or_init_project(project, cx)
2161            .context
2162            .update(cx, |store, cx| {
2163                store.refresh(buffer.clone(), cursor_position, cx);
2164            });
2165    }
2166
2167    #[cfg(feature = "cli-support")]
2168    pub fn set_context_for_buffer(
2169        &mut self,
2170        project: &Entity<Project>,
2171        related_files: Vec<RelatedFile>,
2172        cx: &mut Context<Self>,
2173    ) {
2174        self.get_or_init_project(project, cx)
2175            .context
2176            .update(cx, |store, cx| {
2177                store.set_related_files(related_files, cx);
2178            });
2179    }
2180
2181    #[cfg(feature = "cli-support")]
2182    pub fn set_recent_paths_for_project(
2183        &mut self,
2184        project: &Entity<Project>,
2185        paths: impl IntoIterator<Item = project::ProjectPath>,
2186        cx: &mut Context<Self>,
2187    ) {
2188        let project_state = self.get_or_init_project(project, cx);
2189        project_state.recent_paths = paths.into_iter().collect();
2190    }
2191
2192    fn is_file_open_source(
2193        &self,
2194        project: &Entity<Project>,
2195        file: &Arc<dyn File>,
2196        cx: &App,
2197    ) -> bool {
2198        if !file.is_local() || file.is_private() {
2199            return false;
2200        }
2201        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2202            return false;
2203        };
2204        project_state
2205            .license_detection_watchers
2206            .get(&file.worktree_id(cx))
2207            .as_ref()
2208            .is_some_and(|watcher| watcher.is_project_open_source())
2209    }
2210
2211    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2212        self.data_collection_choice.is_enabled(cx)
2213    }
2214
2215    fn load_data_collection_choice() -> DataCollectionChoice {
2216        let choice = KEY_VALUE_STORE
2217            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2218            .log_err()
2219            .flatten();
2220
2221        match choice.as_deref() {
2222            Some("true") => DataCollectionChoice::Enabled,
2223            Some("false") => DataCollectionChoice::Disabled,
2224            Some(_) => {
2225                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2226                DataCollectionChoice::NotAnswered
2227            }
2228            None => DataCollectionChoice::NotAnswered,
2229        }
2230    }
2231
2232    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2233        self.data_collection_choice = self.data_collection_choice.toggle();
2234        let new_choice = self.data_collection_choice;
2235        let is_enabled = new_choice.is_enabled(cx);
2236        db::write_and_log(cx, move || {
2237            KEY_VALUE_STORE.write_kvp(
2238                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2239                is_enabled.to_string(),
2240            )
2241        });
2242    }
2243
2244    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2245        self.shown_predictions.iter()
2246    }
2247
2248    pub fn shown_completions_len(&self) -> usize {
2249        self.shown_predictions.len()
2250    }
2251
2252    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2253        self.rated_predictions.contains(id)
2254    }
2255
2256    pub fn rate_prediction(
2257        &mut self,
2258        prediction: &EditPrediction,
2259        rating: EditPredictionRating,
2260        feedback: String,
2261        cx: &mut Context<Self>,
2262    ) {
2263        self.rated_predictions.insert(prediction.id.clone());
2264        telemetry::event!(
2265            "Edit Prediction Rated",
2266            request_id = prediction.id.to_string(),
2267            rating,
2268            inputs = prediction.inputs,
2269            output = prediction
2270                .edit_preview
2271                .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2272            feedback
2273        );
2274        self.client.telemetry().flush_events().detach();
2275        cx.notify();
2276    }
2277}
2278
/// Collapses the run of trailing events in `events` that can be merged into a single
/// `BufferChange` event.
///
/// Walking backwards from the newest event, an older event joins the run while
/// `can_merge` accepts it against the next-newer event. `latest_snapshot` and
/// `latest_edit_range` are passed to that check. If the run has two or more events, it is
/// replaced by one event whose diff spans from the oldest event's `old_snapshot` to
/// `end_snapshot`. The merged event is marked `predicted` only if every merged event was.
///
/// Nothing changes when:
/// - the newest event belongs to a different buffer (by `remote_id`) than `latest_snapshot`;
/// - fewer than two events are mergeable;
/// - `compute_diff_between_snapshots` returns `None`.
///
/// NOTE(review): the distinction between `end_snapshot` and `latest_snapshot` is defined by
/// the caller — confirm before relying on it.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
    }

    // Count the trailing run; the newest event is always counted.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    if mergeable_count <= 1 {
        return;
    }

    // `peek` leaves the oldest event in the iterator, so the `all` below also checks it.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                // The merged edit range is anchored in `end_snapshot`.
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2345
2346pub(crate) fn filter_redundant_excerpts(
2347    mut related_files: Vec<RelatedFile>,
2348    cursor_path: &Path,
2349    cursor_row_range: Range<u32>,
2350) -> Vec<RelatedFile> {
2351    for file in &mut related_files {
2352        if file.path.as_ref() == cursor_path {
2353            file.excerpts.retain(|excerpt| {
2354                excerpt.row_range.start < cursor_row_range.start
2355                    || excerpt.row_range.end > cursor_row_range.end
2356            });
2357        }
2358    }
2359    related_files.retain(|file| !file.excerpts.is_empty());
2360    related_files
2361}
2362
2363#[derive(Error, Debug)]
2364#[error(
2365    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2366)]
2367pub struct ZedUpdateRequiredError {
2368    minimum_version: Version,
2369}
2370
2371#[derive(Debug, Clone, Copy)]
2372pub enum DataCollectionChoice {
2373    NotAnswered,
2374    Enabled,
2375    Disabled,
2376}
2377
2378impl DataCollectionChoice {
2379    pub fn is_enabled(self, cx: &App) -> bool {
2380        if cx.is_staff() {
2381            return true;
2382        }
2383        match self {
2384            Self::Enabled => true,
2385            Self::NotAnswered | Self::Disabled => false,
2386        }
2387    }
2388
2389    #[must_use]
2390    pub fn toggle(&self) -> DataCollectionChoice {
2391        match self {
2392            Self::Enabled => Self::Disabled,
2393            Self::Disabled => Self::Enabled,
2394            Self::NotAnswered => Self::Enabled,
2395        }
2396    }
2397}
2398
2399impl From<bool> for DataCollectionChoice {
2400    fn from(value: bool) -> Self {
2401        match value {
2402            true => DataCollectionChoice::Enabled,
2403            false => DataCollectionChoice::Disabled,
2404        }
2405    }
2406}
2407
2408struct ZedPredictUpsell;
2409
2410impl Dismissable for ZedPredictUpsell {
2411    const KEY: &'static str = "dismissed-edit-predict-upsell";
2412
2413    fn dismissed() -> bool {
2414        // To make this backwards compatible with older versions of Zed, we
2415        // check if the user has seen the previous Edit Prediction Onboarding
2416        // before, by checking the data collection choice which was written to
2417        // the database once the user clicked on "Accept and Enable"
2418        if KEY_VALUE_STORE
2419            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2420            .log_err()
2421            .is_some_and(|s| s.is_some())
2422        {
2423            return true;
2424        }
2425
2426        KEY_VALUE_STORE
2427            .read_kvp(Self::KEY)
2428            .log_err()
2429            .is_some_and(|s| s.is_some())
2430    }
2431}
2432
2433pub fn should_show_upsell_modal() -> bool {
2434    !ZedPredictUpsell::dismissed()
2435}
2436
2437pub fn init(cx: &mut App) {
2438    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2439        workspace.register_action(
2440            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2441                ZedPredictModal::toggle(
2442                    workspace,
2443                    workspace.user_store().clone(),
2444                    workspace.client().clone(),
2445                    window,
2446                    cx,
2447                )
2448            },
2449        );
2450
2451        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2452            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2453                settings
2454                    .project
2455                    .all_languages
2456                    .edit_predictions
2457                    .get_or_insert_default()
2458                    .provider = Some(EditPredictionProvider::None)
2459            });
2460        });
2461        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2462            EditPredictionStore::try_global(cx).and_then(|store| {
2463                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2464            })
2465        }
2466
2467        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2468            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2469                copilot_ui::initiate_sign_in(copilot, window, cx);
2470            }
2471        });
2472        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2473            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2474                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2475            }
2476        });
2477        workspace.register_action(|workspace, _: &SignOut, window, cx| {
2478            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2479                copilot_ui::initiate_sign_out(copilot, window, cx);
2480            }
2481        });
2482    })
2483    .detach();
2484}