edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{
   5    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   6};
   7use cloud_llm_client::{
   8    EditPredictionRejectReason, EditPredictionRejection,
   9    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  10    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  11};
  12use collections::{HashMap, HashSet};
  13use copilot::{Copilot, Reinstall, SignIn, SignOut};
  14use db::kvp::{Dismissable, KEY_VALUE_STORE};
  15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::{
  18    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  19    channel::mpsc::{self, UnboundedReceiver},
  20    select_biased,
  21};
  22use gpui::BackgroundExecutor;
  23use gpui::http_client::Url;
  24use gpui::{
  25    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  26    http_client::{self, AsyncBody, Method},
  27    prelude::*,
  28};
  29use language::language_settings::all_language_settings;
  30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  31use language::{BufferSnapshot, OffsetRangeExt};
  32use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  33use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  34use release_channel::AppVersion;
  35use semver::Version;
  36use serde::de::DeserializeOwned;
  37use settings::{EditPredictionProvider, Settings as _, update_settings_file};
  38use std::collections::{VecDeque, hash_map};
  39use std::env;
  40use text::Edit;
  41use workspace::Workspace;
  42use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  43
  44use std::mem;
  45use std::ops::Range;
  46use std::path::Path;
  47use std::rc::Rc;
  48use std::str::FromStr as _;
  49use std::sync::Arc;
  50use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  51use thiserror::Error;
  52use util::{RangeExt as _, ResultExt as _};
  53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  54
  55pub mod cursor_excerpt;
  56pub mod example_spec;
  57mod license_detection;
  58pub mod mercury;
  59pub mod ollama;
  60mod onboarding_modal;
  61pub mod open_ai_response;
  62mod prediction;
  63pub mod sweep_ai;
  64
  65pub mod udiff;
  66
  67mod capture_example;
  68mod zed_edit_prediction_delegate;
  69pub mod zeta1;
  70pub mod zeta2;
  71
  72#[cfg(test)]
  73mod edit_prediction_tests;
  74
  75use crate::license_detection::LicenseDetectionWatcher;
  76use crate::mercury::Mercury;
  77use crate::ollama::Ollama;
  78use crate::onboarding_modal::ZedPredictModal;
  79pub use crate::prediction::EditPrediction;
  80pub use crate::prediction::EditPredictionId;
  81use crate::prediction::EditPredictionResult;
  82pub use crate::sweep_ai::SweepAi;
  83pub use capture_example::capture_example;
  84pub use language_model::ApiKeyState;
  85pub use telemetry_events::EditPredictionRating;
  86pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  87
  88actions!(
  89    edit_prediction,
  90    [
  91        /// Resets the edit prediction onboarding state.
  92        ResetOnboarding,
  93        /// Clears the edit prediction history.
  94        ClearHistory,
  95    ]
  96);
  97
  98/// Maximum number of events to track.
  99const EVENT_COUNT_MAX: usize = 6;
 100const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 101const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
 102const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 103const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 104
 105pub struct Zeta2FeatureFlag;
 106
 107impl FeatureFlag for Zeta2FeatureFlag {
 108    const NAME: &'static str = "zeta2";
 109
 110    fn enabled_for_staff() -> bool {
 111        true
 112    }
 113}
 114
 115#[derive(Clone)]
 116struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
 117
 118impl Global for EditPredictionStoreGlobal {}
 119
 120/// Configuration for using the raw Zeta2 endpoint.
 121/// When set, the client uses the raw endpoint and constructs the prompt itself.
 122/// The version is also used as the Baseten environment name (lowercased).
 123#[derive(Clone)]
 124pub struct Zeta2RawConfig {
 125    pub model_id: Option<String>,
 126    pub format: ZetaFormat,
 127}
 128
 129pub struct EditPredictionStore {
 130    client: Arc<Client>,
 131    user_store: Entity<UserStore>,
 132    llm_token: LlmApiToken,
 133    _llm_token_subscription: Subscription,
 134    projects: HashMap<EntityId, ProjectState>,
 135    update_required: bool,
 136    edit_prediction_model: EditPredictionModel,
 137    zeta2_raw_config: Option<Zeta2RawConfig>,
 138    pub sweep_ai: SweepAi,
 139    pub mercury: Mercury,
 140    pub ollama: Ollama,
 141    data_collection_choice: DataCollectionChoice,
 142    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
 143    shown_predictions: VecDeque<EditPrediction>,
 144    rated_predictions: HashSet<EditPredictionId>,
 145}
 146
 147#[derive(Copy, Clone, Default, PartialEq, Eq)]
 148pub enum EditPredictionModel {
 149    #[default]
 150    Zeta1,
 151    Zeta2,
 152    Sweep,
 153    Mercury,
 154    Ollama,
 155}
 156
 157#[derive(Clone)]
 158pub struct EditPredictionModelInput {
 159    project: Entity<Project>,
 160    buffer: Entity<Buffer>,
 161    snapshot: BufferSnapshot,
 162    position: Anchor,
 163    events: Vec<Arc<zeta_prompt::Event>>,
 164    related_files: Vec<RelatedFile>,
 165    recent_paths: VecDeque<ProjectPath>,
 166    trigger: PredictEditsRequestTrigger,
 167    diagnostic_search_range: Range<Point>,
 168    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 169    pub user_actions: Vec<UserActionRecord>,
 170}
 171
 172#[derive(Debug)]
 173pub enum DebugEvent {
 174    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
 175    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
 176    EditPredictionStarted(EditPredictionStartedDebugEvent),
 177    EditPredictionFinished(EditPredictionFinishedDebugEvent),
 178}
 179
 180#[derive(Debug)]
 181pub struct ContextRetrievalStartedDebugEvent {
 182    pub project_entity_id: EntityId,
 183    pub timestamp: Instant,
 184    pub search_prompt: String,
 185}
 186
 187#[derive(Debug)]
 188pub struct ContextRetrievalFinishedDebugEvent {
 189    pub project_entity_id: EntityId,
 190    pub timestamp: Instant,
 191    pub metadata: Vec<(&'static str, SharedString)>,
 192}
 193
 194#[derive(Debug)]
 195pub struct EditPredictionStartedDebugEvent {
 196    pub buffer: WeakEntity<Buffer>,
 197    pub position: Anchor,
 198    pub prompt: Option<String>,
 199}
 200
 201#[derive(Debug)]
 202pub struct EditPredictionFinishedDebugEvent {
 203    pub buffer: WeakEntity<Buffer>,
 204    pub position: Anchor,
 205    pub model_output: Option<String>,
 206}
 207
 208const USER_ACTION_HISTORY_SIZE: usize = 16;
 209
 210#[derive(Clone, Debug)]
 211pub struct UserActionRecord {
 212    pub action_type: UserActionType,
 213    pub buffer_id: EntityId,
 214    pub line_number: u32,
 215    pub offset: usize,
 216    pub timestamp_epoch_ms: u64,
 217}
 218
 219#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 220pub enum UserActionType {
 221    InsertChar,
 222    InsertSelection,
 223    DeleteChar,
 224    DeleteSelection,
 225    CursorMovement,
 226}
 227
 228/// An event with associated metadata for reconstructing buffer state.
 229#[derive(Clone)]
 230pub struct StoredEvent {
 231    pub event: Arc<zeta_prompt::Event>,
 232    pub old_snapshot: TextBufferSnapshot,
 233    pub edit_range: Range<Anchor>,
 234}
 235
 236impl StoredEvent {
 237    fn can_merge(
 238        &self,
 239        next_old_event: &&&StoredEvent,
 240        new_snapshot: &TextBufferSnapshot,
 241        last_edit_range: &Range<Anchor>,
 242    ) -> bool {
 243        // Events must be for the same buffer
 244        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 245            return false;
 246        }
 247
 248        let a_is_predicted = matches!(
 249            self.event.as_ref(),
 250            zeta_prompt::Event::BufferChange {
 251                predicted: true,
 252                ..
 253            }
 254        );
 255        let b_is_predicted = matches!(
 256            next_old_event.event.as_ref(),
 257            zeta_prompt::Event::BufferChange {
 258                predicted: true,
 259                ..
 260            }
 261        );
 262
 263        // If events come from the same source (both predicted or both manual) then
 264        // we would have coalesced them already.
 265        if a_is_predicted == b_is_predicted {
 266            return false;
 267        }
 268
 269        let left_range = self.edit_range.to_point(new_snapshot);
 270        let right_range = next_old_event.edit_range.to_point(new_snapshot);
 271        let latest_range = last_edit_range.to_point(&new_snapshot);
 272
 273        // Events near to the latest edit are not merged if their sources differ.
 274        if lines_between_ranges(&left_range, &latest_range)
 275            .min(lines_between_ranges(&right_range, &latest_range))
 276            <= CHANGE_GROUPING_LINE_SPAN
 277        {
 278            return false;
 279        }
 280
 281        // Events that are distant from each other are not merged.
 282        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 283            return false;
 284        }
 285
 286        true
 287    }
 288}
 289
 290fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 291    if left.start > right.end {
 292        return left.start.row - right.end.row;
 293    }
 294    if right.start > left.end {
 295        return right.start.row - left.end.row;
 296    }
 297    0
 298}
 299
 300struct ProjectState {
 301    events: VecDeque<StoredEvent>,
 302    last_event: Option<LastEvent>,
 303    recent_paths: VecDeque<ProjectPath>,
 304    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 305    current_prediction: Option<CurrentEditPrediction>,
 306    next_pending_prediction_id: usize,
 307    pending_predictions: ArrayVec<PendingPrediction, 2>,
 308    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 309    last_prediction_refresh: Option<(EntityId, Instant)>,
 310    cancelled_predictions: HashSet<usize>,
 311    context: Entity<RelatedExcerptStore>,
 312    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 313    user_actions: VecDeque<UserActionRecord>,
 314    _subscriptions: [gpui::Subscription; 2],
 315    copilot: Option<Entity<Copilot>>,
 316}
 317
 318impl ProjectState {
 319    fn record_user_action(&mut self, action: UserActionRecord) {
 320        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 321            self.user_actions.pop_front();
 322        }
 323        self.user_actions.push_back(action);
 324    }
 325
 326    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 327        self.events
 328            .iter()
 329            .cloned()
 330            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 331                let (one, two) = event.split_by_pause();
 332                let one = one.finalize(&self.license_detection_watchers, cx);
 333                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 334                one.into_iter().chain(two)
 335            }))
 336            .collect()
 337    }
 338
 339    fn cancel_pending_prediction(
 340        &mut self,
 341        pending_prediction: PendingPrediction,
 342        cx: &mut Context<EditPredictionStore>,
 343    ) {
 344        self.cancelled_predictions.insert(pending_prediction.id);
 345
 346        if pending_prediction.drop_on_cancel {
 347            drop(pending_prediction.task);
 348        } else {
 349            cx.spawn(async move |this, cx| {
 350                let Some(prediction_id) = pending_prediction.task.await else {
 351                    return;
 352                };
 353
 354                this.update(cx, |this, cx| {
 355                    this.reject_prediction(
 356                        prediction_id,
 357                        EditPredictionRejectReason::Canceled,
 358                        false,
 359                        cx,
 360                    );
 361                })
 362                .ok();
 363            })
 364            .detach()
 365        }
 366    }
 367
 368    fn active_buffer(
 369        &self,
 370        project: &Entity<Project>,
 371        cx: &App,
 372    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 373        let project = project.read(cx);
 374        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 375        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 376        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 377        Some((active_buffer, registered_buffer.last_position))
 378    }
 379}
 380
 381#[derive(Debug, Clone)]
 382struct CurrentEditPrediction {
 383    pub requested_by: PredictionRequestedBy,
 384    pub prediction: EditPrediction,
 385    pub was_shown: bool,
 386    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
 387}
 388
 389impl CurrentEditPrediction {
 390    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 391        let Some(new_edits) = self
 392            .prediction
 393            .interpolate(&self.prediction.buffer.read(cx))
 394        else {
 395            return false;
 396        };
 397
 398        if self.prediction.buffer != old_prediction.prediction.buffer {
 399            return true;
 400        }
 401
 402        let Some(old_edits) = old_prediction
 403            .prediction
 404            .interpolate(&old_prediction.prediction.buffer.read(cx))
 405        else {
 406            return true;
 407        };
 408
 409        let requested_by_buffer_id = self.requested_by.buffer_id();
 410
 411        // This reduces the occurrence of UI thrash from replacing edits
 412        //
 413        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 414        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 415            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 416            && old_edits.len() == 1
 417            && new_edits.len() == 1
 418        {
 419            let (old_range, old_text) = &old_edits[0];
 420            let (new_range, new_text) = &new_edits[0];
 421            new_range == old_range && new_text.starts_with(old_text.as_ref())
 422        } else {
 423            true
 424        }
 425    }
 426}
 427
 428#[derive(Debug, Clone)]
 429enum PredictionRequestedBy {
 430    DiagnosticsUpdate,
 431    Buffer(EntityId),
 432}
 433
 434impl PredictionRequestedBy {
 435    pub fn buffer_id(&self) -> Option<EntityId> {
 436        match self {
 437            PredictionRequestedBy::DiagnosticsUpdate => None,
 438            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 439        }
 440    }
 441}
 442
 443#[derive(Debug)]
 444struct PendingPrediction {
 445    id: usize,
 446    task: Task<Option<EditPredictionId>>,
 447    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
 448    /// If false, the task is awaited to completion so rejection can be reported.
 449    drop_on_cancel: bool,
 450}
 451
 452/// A prediction from the perspective of a buffer.
 453#[derive(Debug)]
 454enum BufferEditPrediction<'a> {
 455    Local { prediction: &'a EditPrediction },
 456    Jump { prediction: &'a EditPrediction },
 457}
 458
 459#[cfg(test)]
 460impl std::ops::Deref for BufferEditPrediction<'_> {
 461    type Target = EditPrediction;
 462
 463    fn deref(&self) -> &Self::Target {
 464        match self {
 465            BufferEditPrediction::Local { prediction } => prediction,
 466            BufferEditPrediction::Jump { prediction } => prediction,
 467        }
 468    }
 469}
 470
 471struct RegisteredBuffer {
 472    file: Option<Arc<dyn File>>,
 473    snapshot: TextBufferSnapshot,
 474    last_position: Option<Anchor>,
 475    _subscriptions: [gpui::Subscription; 2],
 476}
 477
 478#[derive(Clone)]
 479struct LastEvent {
 480    old_snapshot: TextBufferSnapshot,
 481    new_snapshot: TextBufferSnapshot,
 482    old_file: Option<Arc<dyn File>>,
 483    new_file: Option<Arc<dyn File>>,
 484    edit_range: Option<Range<Anchor>>,
 485    predicted: bool,
 486    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
 487    last_edit_time: Option<Instant>,
 488}
 489
 490impl LastEvent {
 491    pub fn finalize(
 492        &self,
 493        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 494        cx: &App,
 495    ) -> Option<StoredEvent> {
 496        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 497        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 498
 499        let in_open_source_repo =
 500            [self.new_file.as_ref(), self.old_file.as_ref()]
 501                .iter()
 502                .all(|file| {
 503                    file.is_some_and(|file| {
 504                        license_detection_watchers
 505                            .get(&file.worktree_id(cx))
 506                            .is_some_and(|watcher| watcher.is_project_open_source())
 507                    })
 508                });
 509
 510        let (diff, edit_range) =
 511            compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 512
 513        if path == old_path && diff.is_empty() {
 514            None
 515        } else {
 516            Some(StoredEvent {
 517                event: Arc::new(zeta_prompt::Event::BufferChange {
 518                    old_path,
 519                    path,
 520                    diff,
 521                    in_open_source_repo,
 522                    predicted: self.predicted,
 523                }),
 524                edit_range: self.new_snapshot.anchor_before(edit_range.start)
 525                    ..self.new_snapshot.anchor_before(edit_range.end),
 526                old_snapshot: self.old_snapshot.clone(),
 527            })
 528        }
 529    }
 530
 531    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 532        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 533            return (self.clone(), None);
 534        };
 535
 536        let before = LastEvent {
 537            old_snapshot: self.old_snapshot.clone(),
 538            new_snapshot: boundary_snapshot.clone(),
 539            old_file: self.old_file.clone(),
 540            new_file: self.new_file.clone(),
 541            edit_range: None,
 542            predicted: self.predicted,
 543            snapshot_after_last_editing_pause: None,
 544            last_edit_time: self.last_edit_time,
 545        };
 546
 547        let after = LastEvent {
 548            old_snapshot: boundary_snapshot.clone(),
 549            new_snapshot: self.new_snapshot.clone(),
 550            old_file: self.old_file.clone(),
 551            new_file: self.new_file.clone(),
 552            edit_range: None,
 553            predicted: self.predicted,
 554            snapshot_after_last_editing_pause: None,
 555            last_edit_time: self.last_edit_time,
 556        };
 557
 558        (before, Some(after))
 559    }
 560}
 561
 562pub(crate) fn compute_diff_between_snapshots(
 563    old_snapshot: &TextBufferSnapshot,
 564    new_snapshot: &TextBufferSnapshot,
 565) -> Option<(String, Range<Point>)> {
 566    let edits: Vec<Edit<usize>> = new_snapshot
 567        .edits_since::<usize>(&old_snapshot.version)
 568        .collect();
 569
 570    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 571
 572    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 573    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 574    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 575    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 576
 577    const CONTEXT_LINES: u32 = 3;
 578
 579    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 580    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 581    let old_context_end_row =
 582        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 583    let new_context_end_row =
 584        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 585
 586    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 587    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 588    let old_end_line_offset = old_snapshot
 589        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 590    let new_end_line_offset = new_snapshot
 591        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 592    let old_edit_range = old_start_line_offset..old_end_line_offset;
 593    let new_edit_range = new_start_line_offset..new_end_line_offset;
 594
 595    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 596    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 597
 598    let diff = language::unified_diff_with_offsets(
 599        &old_region_text,
 600        &new_region_text,
 601        old_context_start_row,
 602        new_context_start_row,
 603    );
 604
 605    Some((diff, new_start_point..new_end_point))
 606}
 607
 608fn buffer_path_with_id_fallback(
 609    file: Option<&Arc<dyn File>>,
 610    snapshot: &TextBufferSnapshot,
 611    cx: &App,
 612) -> Arc<Path> {
 613    if let Some(file) = file {
 614        file.full_path(cx).into()
 615    } else {
 616        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 617    }
 618}
 619
 620impl EditPredictionStore {
 621    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 622        cx.try_global::<EditPredictionStoreGlobal>()
 623            .map(|global| global.0.clone())
 624    }
 625
 626    pub fn global(
 627        client: &Arc<Client>,
 628        user_store: &Entity<UserStore>,
 629        cx: &mut App,
 630    ) -> Entity<Self> {
 631        cx.try_global::<EditPredictionStoreGlobal>()
 632            .map(|global| global.0.clone())
 633            .unwrap_or_else(|| {
 634                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 635                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 636                ep_store
 637            })
 638    }
 639
 640    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 641        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 642        let data_collection_choice = Self::load_data_collection_choice();
 643
 644        let llm_token = LlmApiToken::default();
 645
 646        let (reject_tx, reject_rx) = mpsc::unbounded();
 647        cx.background_spawn({
 648            let client = client.clone();
 649            let llm_token = llm_token.clone();
 650            let app_version = AppVersion::global(cx);
 651            let background_executor = cx.background_executor().clone();
 652            async move {
 653                Self::handle_rejected_predictions(
 654                    reject_rx,
 655                    client,
 656                    llm_token,
 657                    app_version,
 658                    background_executor,
 659                )
 660                .await
 661            }
 662        })
 663        .detach();
 664
 665        let this = Self {
 666            projects: HashMap::default(),
 667            client,
 668            user_store,
 669            llm_token,
 670            _llm_token_subscription: cx.subscribe(
 671                &refresh_llm_token_listener,
 672                |this, _listener, _event, cx| {
 673                    let client = this.client.clone();
 674                    let llm_token = this.llm_token.clone();
 675                    cx.spawn(async move |_this, _cx| {
 676                        llm_token.refresh(&client).await?;
 677                        anyhow::Ok(())
 678                    })
 679                    .detach_and_log_err(cx);
 680                },
 681            ),
 682            update_required: false,
 683            edit_prediction_model: EditPredictionModel::Zeta2,
 684            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 685            sweep_ai: SweepAi::new(cx),
 686            mercury: Mercury::new(cx),
 687            ollama: Ollama::new(),
 688
 689            data_collection_choice,
 690            reject_predictions_tx: reject_tx,
 691            rated_predictions: Default::default(),
 692            shown_predictions: Default::default(),
 693        };
 694
 695        this
 696    }
 697
 698    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 699        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 700        let format = ZetaFormat::parse(&version_str).ok()?;
 701        let model_id = env::var("ZED_ZETA_MODEL").ok();
 702        Some(Zeta2RawConfig { model_id, format })
 703    }
 704
 705    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 706        self.edit_prediction_model = model;
 707    }
 708
 709    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
 710        self.zeta2_raw_config = Some(config);
 711    }
 712
 713    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
 714        self.zeta2_raw_config.as_ref()
 715    }
 716
 717    pub fn icons(&self) -> edit_prediction_types::EditPredictionIconSet {
 718        use ui::IconName;
 719        match self.edit_prediction_model {
 720            EditPredictionModel::Sweep => {
 721                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 722                    .with_disabled(IconName::SweepAiDisabled)
 723                    .with_up(IconName::SweepAiUp)
 724                    .with_down(IconName::SweepAiDown)
 725                    .with_error(IconName::SweepAiError)
 726            }
 727            EditPredictionModel::Mercury => {
 728                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 729            }
 730            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
 731                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 732                    .with_disabled(IconName::ZedPredictDisabled)
 733                    .with_up(IconName::ZedPredictUp)
 734                    .with_down(IconName::ZedPredictDown)
 735                    .with_error(IconName::ZedPredictError)
 736            }
 737            EditPredictionModel::Ollama => {
 738                edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 739            }
 740        }
 741    }
 742
 743    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 744        self.sweep_ai.api_token.read(cx).has_key()
 745    }
 746
 747    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 748        self.mercury.api_token.read(cx).has_key()
 749    }
 750
 751    pub fn clear_history(&mut self) {
 752        for project_state in self.projects.values_mut() {
 753            project_state.events.clear();
 754            project_state.last_event.take();
 755        }
 756    }
 757
 758    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 759        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 760            project_state.events.clear();
 761            project_state.last_event.take();
 762        }
 763    }
 764
 765    pub fn edit_history_for_project(
 766        &self,
 767        project: &Entity<Project>,
 768        cx: &App,
 769    ) -> Vec<StoredEvent> {
 770        self.projects
 771            .get(&project.entity_id())
 772            .map(|project_state| project_state.events(cx))
 773            .unwrap_or_default()
 774    }
 775
 776    pub fn context_for_project<'a>(
 777        &'a self,
 778        project: &Entity<Project>,
 779        cx: &'a mut App,
 780    ) -> Vec<RelatedFile> {
 781        self.projects
 782            .get(&project.entity_id())
 783            .map(|project_state| {
 784                project_state.context.update(cx, |context, cx| {
 785                    context
 786                        .related_files_with_buffers(cx)
 787                        .map(|(mut related_file, buffer)| {
 788                            related_file.in_open_source_repo = buffer
 789                                .read(cx)
 790                                .file()
 791                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
 792                            related_file
 793                        })
 794                        .collect()
 795                })
 796            })
 797            .unwrap_or_default()
 798    }
 799
 800    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 801        self.projects
 802            .get(&project.entity_id())
 803            .and_then(|project| project.copilot.clone())
 804    }
 805
 806    pub fn start_copilot_for_project(
 807        &mut self,
 808        project: &Entity<Project>,
 809        cx: &mut Context<Self>,
 810    ) -> Option<Entity<Copilot>> {
 811        if DisableAiSettings::get(None, cx).disable_ai {
 812            return None;
 813        }
 814        let state = self.get_or_init_project(project, cx);
 815
 816        if state.copilot.is_some() {
 817            return state.copilot.clone();
 818        }
 819        let _project = project.clone();
 820        let project = project.read(cx);
 821
 822        let node = project.node_runtime().cloned();
 823        if let Some(node) = node {
 824            let next_id = project.languages().next_language_server_id();
 825            let fs = project.fs().clone();
 826
 827            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 828            state.copilot = Some(copilot.clone());
 829            Some(copilot)
 830        } else {
 831            None
 832        }
 833    }
 834
 835    pub fn context_for_project_with_buffers<'a>(
 836        &'a self,
 837        project: &Entity<Project>,
 838        cx: &'a mut App,
 839    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 840        self.projects
 841            .get(&project.entity_id())
 842            .map(|project| {
 843                project.context.update(cx, |context, cx| {
 844                    context.related_files_with_buffers(cx).collect()
 845                })
 846            })
 847            .unwrap_or_default()
 848    }
 849
 850    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 851        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta2) {
 852            self.user_store.read(cx).edit_prediction_usage()
 853        } else {
 854            None
 855        }
 856    }
 857
 858    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 859        self.get_or_init_project(project, cx);
 860    }
 861
 862    pub fn register_buffer(
 863        &mut self,
 864        buffer: &Entity<Buffer>,
 865        project: &Entity<Project>,
 866        cx: &mut Context<Self>,
 867    ) {
 868        let project_state = self.get_or_init_project(project, cx);
 869        Self::register_buffer_impl(project_state, buffer, project, cx);
 870    }
 871
 872    fn get_or_init_project(
 873        &mut self,
 874        project: &Entity<Project>,
 875        cx: &mut Context<Self>,
 876    ) -> &mut ProjectState {
 877        let entity_id = project.entity_id();
 878        self.projects
 879            .entry(entity_id)
 880            .or_insert_with(|| ProjectState {
 881                context: {
 882                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 883                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 884                        this.handle_excerpt_store_event(entity_id, event);
 885                    })
 886                    .detach();
 887                    related_excerpt_store
 888                },
 889                events: VecDeque::new(),
 890                last_event: None,
 891                recent_paths: VecDeque::new(),
 892                debug_tx: None,
 893                registered_buffers: HashMap::default(),
 894                current_prediction: None,
 895                cancelled_predictions: HashSet::default(),
 896                pending_predictions: ArrayVec::new(),
 897                next_pending_prediction_id: 0,
 898                last_prediction_refresh: None,
 899                license_detection_watchers: HashMap::default(),
 900                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
 901                _subscriptions: [
 902                    cx.subscribe(&project, Self::handle_project_event),
 903                    cx.observe_release(&project, move |this, _, cx| {
 904                        this.projects.remove(&entity_id);
 905                        cx.notify();
 906                    }),
 907                ],
 908                copilot: None,
 909            })
 910    }
 911
 912    pub fn remove_project(&mut self, project: &Entity<Project>) {
 913        self.projects.remove(&project.entity_id());
 914    }
 915
 916    fn handle_excerpt_store_event(
 917        &mut self,
 918        project_entity_id: EntityId,
 919        event: &RelatedExcerptStoreEvent,
 920    ) {
 921        if let Some(project_state) = self.projects.get(&project_entity_id) {
 922            if let Some(debug_tx) = project_state.debug_tx.clone() {
 923                match event {
 924                    RelatedExcerptStoreEvent::StartedRefresh => {
 925                        debug_tx
 926                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 927                                ContextRetrievalStartedDebugEvent {
 928                                    project_entity_id: project_entity_id,
 929                                    timestamp: Instant::now(),
 930                                    search_prompt: String::new(),
 931                                },
 932                            ))
 933                            .ok();
 934                    }
 935                    RelatedExcerptStoreEvent::FinishedRefresh {
 936                        cache_hit_count,
 937                        cache_miss_count,
 938                        mean_definition_latency,
 939                        max_definition_latency,
 940                    } => {
 941                        debug_tx
 942                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 943                                ContextRetrievalFinishedDebugEvent {
 944                                    project_entity_id: project_entity_id,
 945                                    timestamp: Instant::now(),
 946                                    metadata: vec![
 947                                        (
 948                                            "Cache Hits",
 949                                            format!(
 950                                                "{}/{}",
 951                                                cache_hit_count,
 952                                                cache_hit_count + cache_miss_count
 953                                            )
 954                                            .into(),
 955                                        ),
 956                                        (
 957                                            "Max LSP Time",
 958                                            format!("{} ms", max_definition_latency.as_millis())
 959                                                .into(),
 960                                        ),
 961                                        (
 962                                            "Mean LSP Time",
 963                                            format!("{} ms", mean_definition_latency.as_millis())
 964                                                .into(),
 965                                        ),
 966                                    ],
 967                                },
 968                            ))
 969                            .ok();
 970                    }
 971                }
 972            }
 973        }
 974    }
 975
 976    pub fn debug_info(
 977        &mut self,
 978        project: &Entity<Project>,
 979        cx: &mut Context<Self>,
 980    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 981        let project_state = self.get_or_init_project(project, cx);
 982        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 983        project_state.debug_tx = Some(debug_watch_tx);
 984        debug_watch_rx
 985    }
 986
 987    fn handle_project_event(
 988        &mut self,
 989        project: Entity<Project>,
 990        event: &project::Event,
 991        cx: &mut Context<Self>,
 992    ) {
 993        // TODO [zeta2] init with recent paths
 994        match event {
 995            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 996                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 997                    return;
 998                };
 999                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1000                if let Some(path) = path {
1001                    if let Some(ix) = project_state
1002                        .recent_paths
1003                        .iter()
1004                        .position(|probe| probe == &path)
1005                    {
1006                        project_state.recent_paths.remove(ix);
1007                    }
1008                    project_state.recent_paths.push_front(path);
1009                }
1010            }
1011            project::Event::DiagnosticsUpdated { .. } => {
1012                if cx.has_flag::<Zeta2FeatureFlag>() {
1013                    self.refresh_prediction_from_diagnostics(project, cx);
1014                }
1015            }
1016            _ => (),
1017        }
1018    }
1019
    /// Registers `buffer` in `project_state` and returns its tracking entry.
    ///
    /// On every call, if the buffer has a file whose worktree is found in `project`, a
    /// `LicenseDetectionWatcher` is created for that worktree unless one already exists. The
    /// watcher is removed from `project_state` when the worktree is released.
    ///
    /// If the buffer is not yet registered, a `RegisteredBuffer` is inserted with:
    /// - the buffer's current text snapshot and file;
    /// - a subscription that calls `report_changes_for_buffer` on edits;
    /// - an observer that removes the entry once the buffer is released.
    ///
    /// Buffers that are already registered are returned as-is.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget this worktree's watcher once the worktree itself is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Report edits. The project is captured as a weak handle and
                        // edits are ignored once it can no longer be upgraded.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1085
    /// Records the changes made to `buffer` since the last snapshot stored for it.
    ///
    /// Returns early if the buffer's version is unchanged or no edits are found. Otherwise it:
    /// - replaces the registered snapshot and file with the current ones;
    /// - classifies the edits as a `UserActionType` and records a user action at the end
    ///   offset of the last edit;
    /// - coalesces the change into `last_event`, or finalizes the previous `last_event` into
    ///   `events` and starts a new one.
    ///
    /// Coalescing requires all of the following:
    /// - the change continues from that event's snapshot of the same buffer;
    /// - `is_predicted` matches the event's source;
    /// - the change is within `CHANGE_GROUPING_LINE_SPAN` lines of the event's edit range.
    ///
    /// `is_predicted` is `true` when called from `accept_current_prediction` and `false` for
    /// ordinary buffer edits.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        // Aggregate all edits since the old snapshot: counts, total sizes, the overall anchor
        // range (first start .. last end), and the end offset of the final edit.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // "Char" actions: every edit inserts (or, for pure deletions, deletes) exactly one
        // offset unit. Otherwise the action is treated as a selection-sized change.
        // NOTE(review): if usize offsets are byte offsets, multi-byte characters are
        // classified as selections — confirm this is intended.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // Wall-clock timestamp; falls back to 0 if the clock is before the Unix epoch.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // If at least LAST_CHANGE_GROUPING_TIME passed since the previous edit,
                // remember the snapshot from before this edit as the last editing pause.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescible: move the previous event into the bounded history.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Pops when len + 1 >= EVENT_COUNT_MAX, so `events` holds at most
                // EVENT_COUNT_MAX - 1 entries (presumably leaving room for `last_event`).
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1213
1214    fn prediction_at(
1215        &mut self,
1216        buffer: &Entity<Buffer>,
1217        position: Option<language::Anchor>,
1218        project: &Entity<Project>,
1219        cx: &App,
1220    ) -> Option<BufferEditPrediction<'_>> {
1221        let project_state = self.projects.get_mut(&project.entity_id())?;
1222        if let Some(position) = position
1223            && let Some(buffer) = project_state
1224                .registered_buffers
1225                .get_mut(&buffer.entity_id())
1226        {
1227            buffer.last_position = Some(position);
1228        }
1229
1230        let CurrentEditPrediction {
1231            requested_by,
1232            prediction,
1233            ..
1234        } = project_state.current_prediction.as_ref()?;
1235
1236        if prediction.targets_buffer(buffer.read(cx)) {
1237            Some(BufferEditPrediction::Local { prediction })
1238        } else {
1239            let show_jump = match requested_by {
1240                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1241                    requested_by_buffer_id == &buffer.entity_id()
1242                }
1243                PredictionRequestedBy::DiagnosticsUpdate => true,
1244            };
1245
1246            if show_jump {
1247                Some(BufferEditPrediction::Jump { prediction })
1248            } else {
1249                None
1250            }
1251        }
1252    }
1253
1254    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1255        let Some(current_prediction) = self
1256            .projects
1257            .get_mut(&project.entity_id())
1258            .and_then(|project_state| project_state.current_prediction.take())
1259        else {
1260            return;
1261        };
1262
1263        self.report_changes_for_buffer(&current_prediction.prediction.buffer, project, true, cx);
1264
1265        // can't hold &mut project_state ref across report_changes_for_buffer_call
1266        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1267            return;
1268        };
1269
1270        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1271            project_state.cancel_pending_prediction(pending_prediction, cx);
1272        }
1273
1274        match self.edit_prediction_model {
1275            EditPredictionModel::Sweep => {
1276                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1277            }
1278            EditPredictionModel::Mercury => {
1279                mercury::edit_prediction_accepted(
1280                    current_prediction.prediction.id,
1281                    self.client.http_client(),
1282                    cx,
1283                );
1284            }
1285            EditPredictionModel::Ollama => {}
1286            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1287                zeta2::edit_prediction_accepted(self, current_prediction, cx)
1288            }
1289        }
1290    }
1291
1292    async fn handle_rejected_predictions(
1293        rx: UnboundedReceiver<EditPredictionRejection>,
1294        client: Arc<Client>,
1295        llm_token: LlmApiToken,
1296        app_version: Version,
1297        background_executor: BackgroundExecutor,
1298    ) {
1299        let mut rx = std::pin::pin!(rx.peekable());
1300        let mut batched = Vec::new();
1301
1302        while let Some(rejection) = rx.next().await {
1303            batched.push(rejection);
1304
1305            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1306                select_biased! {
1307                    next = rx.as_mut().peek().fuse() => {
1308                        if next.is_some() {
1309                            continue;
1310                        }
1311                    }
1312                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1313                }
1314            }
1315
1316            let url = client
1317                .http_client()
1318                .build_zed_llm_url("/predict_edits/reject", &[])
1319                .unwrap();
1320
1321            let flush_count = batched
1322                .len()
1323                // in case items have accumulated after failure
1324                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1325            let start = batched.len() - flush_count;
1326
1327            let body = RejectEditPredictionsBodyRef {
1328                rejections: &batched[start..],
1329            };
1330
1331            let result = Self::send_api_request::<()>(
1332                |builder| {
1333                    let req = builder
1334                        .uri(url.as_ref())
1335                        .body(serde_json::to_string(&body)?.into());
1336                    anyhow::Ok(req?)
1337                },
1338                client.clone(),
1339                llm_token.clone(),
1340                app_version.clone(),
1341                true,
1342            )
1343            .await;
1344
1345            if result.log_err().is_some() {
1346                batched.drain(start..);
1347            }
1348        }
1349    }
1350
1351    fn reject_current_prediction(
1352        &mut self,
1353        reason: EditPredictionRejectReason,
1354        project: &Entity<Project>,
1355        cx: &App,
1356    ) {
1357        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1358            project_state.pending_predictions.clear();
1359            if let Some(prediction) = project_state.current_prediction.take() {
1360                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
1361            }
1362        };
1363    }
1364
1365    fn did_show_current_prediction(
1366        &mut self,
1367        project: &Entity<Project>,
1368        display_type: edit_prediction_types::SuggestionDisplayType,
1369        cx: &mut Context<Self>,
1370    ) {
1371        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1372            return;
1373        };
1374
1375        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1376            return;
1377        };
1378
1379        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1380        let previous_shown_with = current_prediction.shown_with;
1381
1382        if previous_shown_with.is_none() || !is_jump {
1383            current_prediction.shown_with = Some(display_type);
1384        }
1385
1386        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1387
1388        if is_first_non_jump_show {
1389            current_prediction.was_shown = true;
1390        }
1391
1392        let display_type_changed = previous_shown_with != Some(display_type);
1393
1394        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1395            sweep_ai::edit_prediction_shown(
1396                &self.sweep_ai,
1397                self.client.clone(),
1398                &current_prediction.prediction,
1399                display_type,
1400                cx,
1401            );
1402        }
1403
1404        if is_first_non_jump_show {
1405            self.shown_predictions
1406                .push_front(current_prediction.prediction.clone());
1407            if self.shown_predictions.len() > 50 {
1408                let completion = self.shown_predictions.pop_back().unwrap();
1409                self.rated_predictions.remove(&completion.id);
1410            }
1411        }
1412    }
1413
1414    fn reject_prediction(
1415        &mut self,
1416        prediction_id: EditPredictionId,
1417        reason: EditPredictionRejectReason,
1418        was_shown: bool,
1419        cx: &App,
1420    ) {
1421        match self.edit_prediction_model {
1422            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1423                self.reject_predictions_tx
1424                    .unbounded_send(EditPredictionRejection {
1425                        request_id: prediction_id.to_string(),
1426                        reason,
1427                        was_shown,
1428                    })
1429                    .log_err();
1430            }
1431            EditPredictionModel::Sweep | EditPredictionModel::Ollama => {}
1432            EditPredictionModel::Mercury => {
1433                mercury::edit_prediction_rejected(
1434                    prediction_id,
1435                    was_shown,
1436                    reason,
1437                    self.client.http_client(),
1438                    cx,
1439                );
1440            }
1441        }
1442    }
1443
1444    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1445        self.projects
1446            .get(&project.entity_id())
1447            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1448    }
1449
    /// Queues a prediction request for `buffer` at `position`.
    ///
    /// The request goes through `queue_prediction_refresh` with the buffer's entity id as
    /// the throttle key and uses `PredictEditsRequestTrigger::Other`. A resulting prediction
    /// is tagged `PredictionRequestedBy::Buffer(buffer id)`.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            // If `update` fails (e.g. the store was released), log it and yield no prediction.
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &buffer,
                        position,
                        PredictEditsRequestTrigger::Other,
                        cx,
                    )
                })
                .log_err()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |_cx| {
                // Attach the requesting buffer to any successful result.
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }
1485
    /// Queues a prediction request triggered by a diagnostics update, throttled per project.
    ///
    /// Skipped if the project is not tracked or already has a current prediction.
    ///
    /// Once the queued task runs, it:
    /// 1. Takes the active buffer and cursor, stopping if predictions are disabled there.
    /// 2. Asks `next_diagnostic_location` for a target buffer and position.
    /// 3. Requests a prediction there with `PredictEditsRequestTrigger::Diagnostics`.
    ///
    /// If a current prediction appeared while the request was in flight, the new result is
    /// turned into a `CurrentPreferred` rejection instead.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Snapshot the active buffer and cursor, or bail with "no prediction".
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    // Without a known cursor position, fall back to the start of the buffer.
                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // A current prediction may have been set while this request was in flight;
                // if so, keep it and report the new result as rejected.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1572
1573    fn predictions_enabled_at(
1574        snapshot: &BufferSnapshot,
1575        position: Option<language::Anchor>,
1576        cx: &App,
1577    ) -> bool {
1578        let file = snapshot.file();
1579        let all_settings = all_language_settings(file, cx);
1580        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1581            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1582        {
1583            return false;
1584        }
1585
1586        if let Some(last_position) = position {
1587            let settings = snapshot.settings_at(last_position, cx);
1588
1589            if !settings.edit_predictions_disabled_in.is_empty()
1590                && let Some(scope) = snapshot.language_scope_at(last_position)
1591                && let Some(scope_name) = scope.override_name()
1592                && settings
1593                    .edit_predictions_disabled_in
1594                    .iter()
1595                    .any(|s| s == scope_name)
1596            {
1597                return false;
1598            }
1599        }
1600
1601        true
1602    }
1603
    /// Minimum delay between consecutive prediction refreshes queued for the same throttle
    /// entity (compared against `last_prediction_refresh` in `queue_prediction_refresh`).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Throttling is disabled under test so refreshes run without waiting.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1608
    /// Queues a throttled prediction refresh for `project` and tracks it as a pending
    /// prediction in the project's state.
    ///
    /// If the previous refresh was for the same `throttle_entity`, the spawned task first
    /// sleeps until [`Self::THROTTLE_TIMEOUT`] has elapsed since that refresh. It then
    /// bails out if this pending prediction was cancelled in the meantime; otherwise it
    /// records the refresh time and runs `do_refresh`.
    ///
    /// When `do_refresh` finishes and this pending prediction was not cancelled:
    /// - an `Ok` prediction becomes the project's current prediction, unless a current
    ///   prediction already exists that it should not replace, in which case the new one
    ///   is rejected as `CurrentPreferred`; when it does replace it, the old one is
    ///   rejected as `Replaced`;
    /// - an `Err` result is rejected with its own reason.
    ///
    /// Either way, the finished entry is removed from the pending list and every entry
    /// queued before it is cancelled.
    ///
    /// At most two predictions are kept pending (one for Ollama). When the list is full,
    /// the most recently queued entry is popped and cancelled to make room for this one.
    /// The spawned task resolves to the id of the prediction produced by `do_refresh`,
    /// if the refresh ran and produced one.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        // Ollama gets a single pending slot and `drop_on_cancel` pending entries
        // (presumably so cancelling actually stops the local request — TODO confirm
        // against `cancel_pending_prediction`).
        let is_ollama = self.edit_prediction_model == EditPredictionModel::Ollama;
        let drop_on_cancel = is_ollama;
        let max_pending_predictions = if is_ollama { 1 } else { 2 };
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        // (entity, timestamp) of the previous refresh, captured before spawning.
        let last_request = project_state.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Throttle only back-to-back refreshes for the same entity.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            // Refresh errors are logged and treated the same as "no prediction".
            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                // A cancellation that arrived while the request was in flight discards
                // the result without adopting or rejecting it.
                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                                shown_with: None,
                            };

                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                        cx,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                        cx,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false, cx);
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-fetch: the `reject_*` calls above needed `this` mutably.
                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Take the list out of the state so entries can be cancelled through
                // `project_state` while iterating; it is written back below.
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        if project_state.pending_predictions.len() < max_pending_predictions {
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
        } else {
            // Full: keep the oldest entry running and replace the newest with this one.
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1756
1757    pub fn request_prediction(
1758        &mut self,
1759        project: &Entity<Project>,
1760        active_buffer: &Entity<Buffer>,
1761        position: language::Anchor,
1762        trigger: PredictEditsRequestTrigger,
1763        cx: &mut Context<Self>,
1764    ) -> Task<Result<Option<EditPredictionResult>>> {
1765        self.request_prediction_internal(
1766            project.clone(),
1767            active_buffer.clone(),
1768            position,
1769            trigger,
1770            cx.has_flag::<Zeta2FeatureFlag>(),
1771            cx,
1772        )
1773    }
1774
1775    fn request_prediction_internal(
1776        &mut self,
1777        project: Entity<Project>,
1778        active_buffer: Entity<Buffer>,
1779        position: language::Anchor,
1780        trigger: PredictEditsRequestTrigger,
1781        allow_jump: bool,
1782        cx: &mut Context<Self>,
1783    ) -> Task<Result<Option<EditPredictionResult>>> {
1784        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1785
1786        self.get_or_init_project(&project, cx);
1787        let project_state = self.projects.get(&project.entity_id()).unwrap();
1788        let stored_events = project_state.events(cx);
1789        let has_events = !stored_events.is_empty();
1790        let events: Vec<Arc<zeta_prompt::Event>> =
1791            stored_events.into_iter().map(|e| e.event).collect();
1792        let debug_tx = project_state.debug_tx.clone();
1793
1794        let snapshot = active_buffer.read(cx).snapshot();
1795        let cursor_point = position.to_point(&snapshot);
1796        let current_offset = position.to_offset(&snapshot);
1797
1798        let mut user_actions: Vec<UserActionRecord> =
1799            project_state.user_actions.iter().cloned().collect();
1800
1801        if let Some(last_action) = user_actions.last() {
1802            if last_action.buffer_id == active_buffer.entity_id()
1803                && current_offset != last_action.offset
1804            {
1805                let timestamp_epoch_ms = SystemTime::now()
1806                    .duration_since(UNIX_EPOCH)
1807                    .map(|d| d.as_millis() as u64)
1808                    .unwrap_or(0);
1809                user_actions.push(UserActionRecord {
1810                    action_type: UserActionType::CursorMovement,
1811                    buffer_id: active_buffer.entity_id(),
1812                    line_number: cursor_point.row,
1813                    offset: current_offset,
1814                    timestamp_epoch_ms,
1815                });
1816            }
1817        }
1818        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1819        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1820        let diagnostic_search_range =
1821            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1822
1823        let related_files = self.context_for_project(&project, cx);
1824
1825        let inputs = EditPredictionModelInput {
1826            project: project.clone(),
1827            buffer: active_buffer.clone(),
1828            snapshot: snapshot.clone(),
1829            position,
1830            events,
1831            related_files,
1832            recent_paths: project_state.recent_paths.clone(),
1833            trigger,
1834            diagnostic_search_range: diagnostic_search_range.clone(),
1835            debug_tx,
1836            user_actions,
1837        };
1838
1839        let task = match &self.edit_prediction_model {
1840            EditPredictionModel::Zeta1 => zeta2::request_prediction_with_zeta2(
1841                self,
1842                inputs,
1843                Some(zeta_prompt::EditPredictionModelKind::Zeta1),
1844                cx,
1845            ),
1846            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
1847                self,
1848                inputs,
1849                Some(zeta_prompt::EditPredictionModelKind::Zeta2),
1850                cx,
1851            ),
1852            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1853            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1854            EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx),
1855        };
1856
1857        cx.spawn(async move |this, cx| {
1858            let prediction = task.await?;
1859
1860            if prediction.is_none() && allow_jump {
1861                let cursor_point = position.to_point(&snapshot);
1862                if has_events
1863                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1864                        active_buffer.clone(),
1865                        &snapshot,
1866                        diagnostic_search_range,
1867                        cursor_point,
1868                        &project,
1869                        cx,
1870                    )
1871                    .await?
1872                {
1873                    return this
1874                        .update(cx, |this, cx| {
1875                            this.request_prediction_internal(
1876                                project,
1877                                jump_buffer,
1878                                jump_position,
1879                                trigger,
1880                                false,
1881                                cx,
1882                            )
1883                        })?
1884                        .await;
1885                }
1886
1887                return anyhow::Ok(None);
1888            }
1889
1890            Ok(prediction)
1891        })
1892    }
1893
    /// Picks a diagnostic position to jump to when no prediction was produced at the cursor.
    ///
    /// First, among the active buffer's diagnostic groups, it considers each group's
    /// primary entry whose range does not overlap `active_buffer_diagnostic_search_range`.
    /// It chooses the one whose start row is closest to `active_buffer_cursor_point`.
    ///
    /// If there is none, it falls back to another project file with diagnostics, taken
    /// from `diagnostic_summaries(false, ..)`. It chooses the file whose relative path
    /// shares the most leading components with the active buffer's path, opens it, and
    /// returns the start of its diagnostic with the smallest severity value. That is
    /// presumably the most severe one, assuming LSP-style ordering where errors sort
    /// first; TODO confirm.
    ///
    /// Returns `Ok(None)` when no candidate exists. Errors opening the fallback buffer
    /// are propagated.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // `None` when the active buffer has no backing file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        // NOTE(review): only the relative paths are compared; `worktree_id`
                        // is ignored, so files in other worktrees can rank highly.
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            });

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1972
1973    async fn send_raw_llm_request(
1974        request: RawCompletionRequest,
1975        client: Arc<Client>,
1976        custom_url: Option<Arc<Url>>,
1977        llm_token: LlmApiToken,
1978        app_version: Version,
1979    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1980        let url = if let Some(custom_url) = custom_url {
1981            custom_url.as_ref().clone()
1982        } else {
1983            client
1984                .http_client()
1985                .build_zed_llm_url("/predict_edits/raw", &[])?
1986        };
1987
1988        Self::send_api_request(
1989            |builder| {
1990                let req = builder
1991                    .uri(url.as_ref())
1992                    .body(serde_json::to_string(&request)?.into());
1993                Ok(req?)
1994            },
1995            client,
1996            llm_token,
1997            app_version,
1998            true,
1999        )
2000        .await
2001    }
2002
2003    pub(crate) async fn send_v3_request(
2004        input: ZetaPromptInput,
2005        client: Arc<Client>,
2006        llm_token: LlmApiToken,
2007        app_version: Version,
2008        trigger: PredictEditsRequestTrigger,
2009    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2010        let url = client
2011            .http_client()
2012            .build_zed_llm_url("/predict_edits/v3", &[])?;
2013
2014        let request = PredictEditsV3Request { input, trigger };
2015
2016        let json_bytes = serde_json::to_vec(&request)?;
2017        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2018
2019        Self::send_api_request(
2020            |builder| {
2021                let req = builder
2022                    .uri(url.as_ref())
2023                    .header("Content-Encoding", "zstd")
2024                    .body(compressed.clone().into());
2025                Ok(req?)
2026            },
2027            client,
2028            llm_token,
2029            app_version,
2030            true,
2031        )
2032        .await
2033    }
2034
2035    fn handle_api_response<T>(
2036        this: &WeakEntity<Self>,
2037        response: Result<(T, Option<EditPredictionUsage>)>,
2038        cx: &mut gpui::AsyncApp,
2039    ) -> Result<T> {
2040        match response {
2041            Ok((data, usage)) => {
2042                if let Some(usage) = usage {
2043                    this.update(cx, |this, cx| {
2044                        this.user_store.update(cx, |user_store, cx| {
2045                            user_store.update_edit_prediction_usage(usage, cx);
2046                        });
2047                    })
2048                    .ok();
2049                }
2050                Ok(data)
2051            }
2052            Err(err) => {
2053                if err.is::<ZedUpdateRequiredError>() {
2054                    cx.update(|cx| {
2055                        this.update(cx, |this, _cx| {
2056                            this.update_required = true;
2057                        })
2058                        .ok();
2059
2060                        let error_message: SharedString = err.to_string().into();
2061                        show_app_notification(
2062                            NotificationId::unique::<ZedUpdateRequiredError>(),
2063                            cx,
2064                            move |cx| {
2065                                cx.new(|cx| {
2066                                    ErrorMessagePrompt::new(error_message.clone(), cx)
2067                                        .with_link_button("Update Zed", "https://zed.dev/releases")
2068                                })
2069                            },
2070                        );
2071                    });
2072                }
2073                Err(err)
2074            }
2075        }
2076    }
2077
    /// POSTs a request built by `build` and decodes a JSON `Res` from a successful response.
    ///
    /// Token selection:
    /// - `ZED_PREDICT_EDITS_TOKEN` is used verbatim when that environment variable is set;
    /// - otherwise an LLM token is acquired, and failure to acquire one is an error when
    ///   `require_auth` is true;
    /// - when `require_auth` is false, a failed acquisition means the request is sent
    ///   without an `Authorization` header.
    ///
    /// Every attempt sets `Content-Type: application/json`, the Zed version header, and a
    /// bearer token when available, then hands the builder to `build`. Because `build`
    /// may be called more than once, it must be `Fn`.
    ///
    /// # Errors
    /// - [`ZedUpdateRequiredError`] when the response carries a minimum-required-version
    ///   header newer than `app_version`; this check runs before the status is inspected.
    /// - A non-success status: if a token is present and the response asks for a refresh,
    ///   the token is refreshed and the request retried once. Otherwise the status and
    ///   body are returned as an error.
    /// - Transport, body-read, and JSON-decode failures are propagated.
    ///
    /// Usage headers that fail to parse are ignored (`None`).
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();

        let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
            Some(custom_token)
        } else if require_auth {
            Some(llm_token.acquire(&client).await?)
        } else {
            llm_token.acquire(&client).await.ok()
        };
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // Enforced on every response, including successful ones.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
                // NOTE(review): this also replaces a `ZED_PREDICT_EDITS_TOKEN` override with
                // a freshly refreshed LLM token — confirm that is intended.
                did_retry = true;
                token = Some(llm_token.refresh(&client).await?);
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2149
2150    pub fn refresh_context(
2151        &mut self,
2152        project: &Entity<Project>,
2153        buffer: &Entity<language::Buffer>,
2154        cursor_position: language::Anchor,
2155        cx: &mut Context<Self>,
2156    ) {
2157        self.get_or_init_project(project, cx)
2158            .context
2159            .update(cx, |store, cx| {
2160                store.refresh(buffer.clone(), cursor_position, cx);
2161            });
2162    }
2163
2164    #[cfg(feature = "cli-support")]
2165    pub fn set_context_for_buffer(
2166        &mut self,
2167        project: &Entity<Project>,
2168        related_files: Vec<RelatedFile>,
2169        cx: &mut Context<Self>,
2170    ) {
2171        self.get_or_init_project(project, cx)
2172            .context
2173            .update(cx, |store, cx| {
2174                store.set_related_files(related_files, cx);
2175            });
2176    }
2177
2178    #[cfg(feature = "cli-support")]
2179    pub fn set_recent_paths_for_project(
2180        &mut self,
2181        project: &Entity<Project>,
2182        paths: impl IntoIterator<Item = project::ProjectPath>,
2183        cx: &mut Context<Self>,
2184    ) {
2185        let project_state = self.get_or_init_project(project, cx);
2186        project_state.recent_paths = paths.into_iter().collect();
2187    }
2188
2189    fn is_file_open_source(
2190        &self,
2191        project: &Entity<Project>,
2192        file: &Arc<dyn File>,
2193        cx: &App,
2194    ) -> bool {
2195        if !file.is_local() || file.is_private() {
2196            return false;
2197        }
2198        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2199            return false;
2200        };
2201        project_state
2202            .license_detection_watchers
2203            .get(&file.worktree_id(cx))
2204            .as_ref()
2205            .is_some_and(|watcher| watcher.is_project_open_source())
2206    }
2207
2208    fn load_data_collection_choice() -> DataCollectionChoice {
2209        let choice = KEY_VALUE_STORE
2210            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2211            .log_err()
2212            .flatten();
2213
2214        match choice.as_deref() {
2215            Some("true") => DataCollectionChoice::Enabled,
2216            Some("false") => DataCollectionChoice::Disabled,
2217            Some(_) => {
2218                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2219                DataCollectionChoice::NotAnswered
2220            }
2221            None => DataCollectionChoice::NotAnswered,
2222        }
2223    }
2224
2225    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2226        self.data_collection_choice = self.data_collection_choice.toggle();
2227        let new_choice = self.data_collection_choice;
2228        let is_enabled = new_choice.is_enabled(cx);
2229        db::write_and_log(cx, move || {
2230            KEY_VALUE_STORE.write_kvp(
2231                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2232                is_enabled.to_string(),
2233            )
2234        });
2235    }
2236
2237    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2238        self.shown_predictions.iter()
2239    }
2240
2241    pub fn shown_completions_len(&self) -> usize {
2242        self.shown_predictions.len()
2243    }
2244
2245    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2246        self.rated_predictions.contains(id)
2247    }
2248
2249    pub fn rate_prediction(
2250        &mut self,
2251        prediction: &EditPrediction,
2252        rating: EditPredictionRating,
2253        feedback: String,
2254        cx: &mut Context<Self>,
2255    ) {
2256        self.rated_predictions.insert(prediction.id.clone());
2257        telemetry::event!(
2258            "Edit Prediction Rated",
2259            request_id = prediction.id.to_string(),
2260            rating,
2261            inputs = prediction.inputs,
2262            output = prediction
2263                .edit_preview
2264                .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2265            feedback
2266        );
2267        self.client.telemetry().flush_events().detach();
2268        cx.notify();
2269    }
2270}
2271
/// Collapses the trailing run of mergeable events in `events` into a single event.
///
/// Walking from newest to oldest, the run grows while each older event `can_merge` with
/// the next-newer one, judged against `latest_snapshot` and `latest_edit_range`. Runs of
/// fewer than two events are left alone.
///
/// Otherwise the run is replaced by one `BufferChange` with these properties:
/// - its diff spans from the oldest event's `old_snapshot` to `end_snapshot`;
/// - its paths and open-source flag are taken from the oldest event;
/// - `predicted` is true only if every merged event was predicted;
/// - its edit range is anchored in `end_snapshot`.
///
/// If no diff can be computed between the two snapshots, `events` is left unchanged.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Count how many of the most recent events form a contiguous mergeable run.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    if mergeable_count <= 1 {
        return;
    }

    // `peek` leaves the oldest event in the iterator so `all` below still sees it.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        // Drop the merged run and append its replacement in its place at the back.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2332
2333pub(crate) fn filter_redundant_excerpts(
2334    mut related_files: Vec<RelatedFile>,
2335    cursor_path: &Path,
2336    cursor_row_range: Range<u32>,
2337) -> Vec<RelatedFile> {
2338    for file in &mut related_files {
2339        if file.path.as_ref() == cursor_path {
2340            file.excerpts.retain(|excerpt| {
2341                excerpt.row_range.start < cursor_row_range.start
2342                    || excerpt.row_range.end > cursor_row_range.end
2343            });
2344        }
2345    }
2346    related_files.retain(|file| !file.excerpts.is_empty());
2347    related_files
2348}
2349
/// Returned when a response's minimum-required-version header names a Zed version newer
/// than the running one, meaning edit predictions cannot be used until Zed is updated.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The oldest Zed version the server accepts.
    minimum_version: Version,
}
2357
/// The user's answer to whether edit-prediction data may be collected.
///
/// Persisted in the key-value store as `"true"` / `"false"`. `NotAnswered` stands for
/// "no recognized value stored".
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been recorded yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2364
2365impl DataCollectionChoice {
2366    pub fn is_enabled(self, cx: &App) -> bool {
2367        if cx.is_staff() {
2368            return true;
2369        }
2370        match self {
2371            Self::Enabled => true,
2372            Self::NotAnswered | Self::Disabled => false,
2373        }
2374    }
2375
2376    #[must_use]
2377    pub fn toggle(&self) -> DataCollectionChoice {
2378        match self {
2379            Self::Enabled => Self::Disabled,
2380            Self::Disabled => Self::Enabled,
2381            Self::NotAnswered => Self::Enabled,
2382        }
2383    }
2384}
2385
2386impl From<bool> for DataCollectionChoice {
2387    fn from(value: bool) -> Self {
2388        match value {
2389            true => DataCollectionChoice::Enabled,
2390            false => DataCollectionChoice::Disabled,
2391        }
2392    }
2393}
2394
2395struct ZedPredictUpsell;
2396
2397impl Dismissable for ZedPredictUpsell {
2398    const KEY: &'static str = "dismissed-edit-predict-upsell";
2399
2400    fn dismissed() -> bool {
2401        // To make this backwards compatible with older versions of Zed, we
2402        // check if the user has seen the previous Edit Prediction Onboarding
2403        // before, by checking the data collection choice which was written to
2404        // the database once the user clicked on "Accept and Enable"
2405        if KEY_VALUE_STORE
2406            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2407            .log_err()
2408            .is_some_and(|s| s.is_some())
2409        {
2410            return true;
2411        }
2412
2413        KEY_VALUE_STORE
2414            .read_kvp(Self::KEY)
2415            .log_err()
2416            .is_some_and(|s| s.is_some())
2417    }
2418}
2419
2420pub fn should_show_upsell_modal() -> bool {
2421    !ZedPredictUpsell::dismissed()
2422}
2423
2424pub fn init(cx: &mut App) {
2425    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2426        workspace.register_action(
2427            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2428                ZedPredictModal::toggle(
2429                    workspace,
2430                    workspace.user_store().clone(),
2431                    workspace.client().clone(),
2432                    window,
2433                    cx,
2434                )
2435            },
2436        );
2437
2438        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2439            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2440                settings
2441                    .project
2442                    .all_languages
2443                    .edit_predictions
2444                    .get_or_insert_default()
2445                    .provider = Some(EditPredictionProvider::None)
2446            });
2447        });
2448        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2449            EditPredictionStore::try_global(cx).and_then(|store| {
2450                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2451            })
2452        }
2453
2454        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2455            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2456                copilot_ui::initiate_sign_in(copilot, window, cx);
2457            }
2458        });
2459        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2460            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2461                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2462            }
2463        });
2464        workspace.register_action(|workspace, _: &SignOut, window, cx| {
2465            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2466                copilot_ui::initiate_sign_out(copilot, window, cx);
2467            }
2468        });
2469    })
2470    .detach();
2471}