edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_api_types::SubmitEditPredictionFeedbackBody;
   5use cloud_llm_client::predict_edits_v3::{
   6    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   7};
   8use cloud_llm_client::{
   9    EditPredictionRejectReason, EditPredictionRejection,
  10    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  11    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  12};
  13use collections::{HashMap, HashSet};
  14use copilot::{Copilot, Reinstall, SignIn, SignOut};
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::{
  19    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  20    channel::mpsc::{self, UnboundedReceiver},
  21    select_biased,
  22};
  23use gpui::BackgroundExecutor;
  24use gpui::http_client::Url;
  25use gpui::{
  26    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  27    http_client::{self, AsyncBody, Method},
  28    prelude::*,
  29};
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{
  39    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  40};
  41use std::collections::{VecDeque, hash_map};
  42use std::env;
  43use text::Edit;
  44use workspace::Workspace;
  45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  46
  47use std::mem;
  48use std::ops::Range;
  49use std::path::Path;
  50use std::rc::Rc;
  51use std::str::FromStr as _;
  52use std::sync::Arc;
  53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  54use thiserror::Error;
  55use util::{RangeExt as _, ResultExt as _};
  56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  57
  58pub mod cursor_excerpt;
  59pub mod example_spec;
  60pub mod fim;
  61mod license_detection;
  62pub mod mercury;
  63pub mod ollama;
  64mod onboarding_modal;
  65pub mod open_ai_response;
  66mod prediction;
  67pub mod sweep_ai;
  68
  69pub mod udiff;
  70
  71mod capture_example;
  72mod zed_edit_prediction_delegate;
  73pub mod zeta;
  74
  75#[cfg(test)]
  76mod edit_prediction_tests;
  77
  78use crate::license_detection::LicenseDetectionWatcher;
  79use crate::mercury::Mercury;
  80use crate::onboarding_modal::ZedPredictModal;
  81pub use crate::prediction::EditPrediction;
  82pub use crate::prediction::EditPredictionId;
  83use crate::prediction::EditPredictionResult;
  84pub use crate::sweep_ai::SweepAi;
  85pub use capture_example::capture_example;
  86pub use language_model::ApiKeyState;
  87pub use telemetry_events::EditPredictionRating;
  88pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  89
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  99
 100/// Maximum number of events to track.
 101const EVENT_COUNT_MAX: usize = 6;
 102const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 103const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
 104const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 105const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 106
/// Feature flag registered under the name `"zeta2"`.
// NOTE(review): presumably gates the Zeta2 edit prediction model — confirm at use sites.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Always `true`: staff have this flag enabled.
    fn enabled_for_staff() -> bool {
        true
    }
}
 116
/// App-global slot holding the shared `EditPredictionStore` entity.
///
/// Installed by `EditPredictionStore::global` on first use and read back by
/// `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 121
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
// NOTE(review): this struct has no `version` field; the sentence above presumably
// refers to `format`, which `zeta2_raw_config_from_env` parses from the
// `ZED_ZETA_FORMAT` variable — confirm and update the wording.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier; taken from `ZED_ZETA_MODEL` when the config
    /// is built from the environment.
    pub model_id: Option<String>,
    /// Prompt format; parsed from `ZED_ZETA_FORMAT` when the config is built
    /// from the environment.
    pub format: ZetaFormat,
}
 130
/// Central state for edit predictions: per-project edit history and pending
/// predictions, the selected prediction backend, and the plumbing used to talk
/// to the LLM service.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Refreshed whenever the global `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    // Keeps the token-refresh subscription alive for the store's lifetime.
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    // Initialized to `false`; presumably set when the server demands a newer client — confirm.
    update_required: bool,
    // Selected backend; starts as `Zeta2`.
    edit_prediction_model: EditPredictionModel,
    // Raw-endpoint configuration, initially read from the environment.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    // Feeds the background task spawned in `new` that reports rejected predictions.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    // NOTE(review): presumably predictions shown to the user, kept for rating — confirm.
    shown_predictions: VecDeque<EditPrediction>,
    // NOTE(review): presumably ids of predictions the user has rated — confirm.
    rated_predictions: HashSet<EditPredictionId>,
}
 147
/// Backend used to produce edit predictions.
///
/// `EditPredictionStore::new` starts with `Zeta2`. The selection also decides which
/// icon set `EditPredictionStore::icons` returns and whether
/// `EditPredictionStore::usage` reports Zed usage (only for `Zeta1`/`Zeta2`).
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta1,
    Zeta2,
    /// Prompt-format-driven backend (presumably fill-in-the-middle; see the `fim`
    /// module). Its icon depends on the configured edit prediction provider.
    Fim { format: EditPredictionPromptFormat },
    Sweep,
    Mercury,
}
 156
/// Everything a prediction backend receives for a single prediction request.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    // Snapshot of `buffer` that the request was built from.
    snapshot: BufferSnapshot,
    // Position in `buffer` the prediction is requested at (presumably the cursor — confirm).
    position: Anchor,
    // Recent edit events, as produced from the project's edit history.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    // Row range searched for diagnostics — TODO confirm how it is derived.
    diagnostic_search_range: Range<Point>,
    // When set, backends can report `DebugEvent`s through this channel.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    // Recent user actions (presumably copied from the project's bounded history).
    pub user_actions: Vec<UserActionRecord>,
}
 171
/// Diagnostic event describing the context-retrieval and prediction lifecycle.
///
/// Sent over the optional `debug_tx` channels held by project state and model inputs.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
 179
/// Payload of [`DebugEvent::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    /// Entity id of the project whose context is being retrieved.
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}
 186
/// Payload of [`DebugEvent::ContextRetrievalFinished`].
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    /// Entity id of the project whose context was retrieved.
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Free-form key/value details about the retrieval.
    pub metadata: Vec<(&'static str, SharedString)>,
}
 193
/// Payload of [`DebugEvent::EditPredictionStarted`].
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Prompt sent to the model, when the backend exposes it.
    pub prompt: Option<String>,
}
 200
/// Payload of [`DebugEvent::EditPredictionFinished`].
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when the backend exposes it.
    pub model_output: Option<String>,
}
 207
/// Capacity of each project's user-action history; `record_user_action` evicts
/// the oldest entry once this many actions are stored.
const USER_ACTION_HISTORY_SIZE: usize = 16;
 209
/// A single recorded user action within a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    /// Wall-clock time of the action, in milliseconds since the Unix epoch.
    pub timestamp_epoch_ms: u64,
}
 218
/// Kind of user action captured in a [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 227
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-level event (e.g. a buffer change with its diff).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer contents before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Range covered by the change, anchored in the buffer as it was after the change.
    pub edit_range: Range<Anchor>,
}
 235
 236impl StoredEvent {
 237    fn can_merge(
 238        &self,
 239        next_old_event: &&&StoredEvent,
 240        new_snapshot: &TextBufferSnapshot,
 241        last_edit_range: &Range<Anchor>,
 242    ) -> bool {
 243        // Events must be for the same buffer
 244        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 245            return false;
 246        }
 247        if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
 248            return false;
 249        }
 250
 251        let a_is_predicted = matches!(
 252            self.event.as_ref(),
 253            zeta_prompt::Event::BufferChange {
 254                predicted: true,
 255                ..
 256            }
 257        );
 258        let b_is_predicted = matches!(
 259            next_old_event.event.as_ref(),
 260            zeta_prompt::Event::BufferChange {
 261                predicted: true,
 262                ..
 263            }
 264        );
 265
 266        // If events come from the same source (both predicted or both manual) then
 267        // we would have coalesced them already.
 268        if a_is_predicted == b_is_predicted {
 269            return false;
 270        }
 271
 272        let left_range = self.edit_range.to_point(new_snapshot);
 273        let right_range = next_old_event.edit_range.to_point(new_snapshot);
 274        let latest_range = last_edit_range.to_point(&new_snapshot);
 275
 276        // Events near to the latest edit are not merged if their sources differ.
 277        if lines_between_ranges(&left_range, &latest_range)
 278            .min(lines_between_ranges(&right_range, &latest_range))
 279            <= CHANGE_GROUPING_LINE_SPAN
 280        {
 281            return false;
 282        }
 283
 284        // Events that are distant from each other are not merged.
 285        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 286            return false;
 287        }
 288
 289        true
 290    }
 291}
 292
 293fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 294    if left.start > right.end {
 295        return left.start.row - right.end.row;
 296    }
 297    if right.start > left.end {
 298        return right.start.row - left.end.row;
 299    }
 300    0
 301}
 302
/// Per-project edit-prediction state tracked by `EditPredictionStore`.
struct ProjectState {
    // Finalized edit events; cleared by `clear_history`.
    events: VecDeque<StoredEvent>,
    // In-progress event not yet moved into `events`; `events()` finalizes it on read.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    // Buffers registered for tracking, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // At most two in-flight predictions (fixed `ArrayVec` capacity).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    // NOTE(review): presumably (buffer id, time) of the last refresh of each kind — confirm.
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions passed to `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    // Related-excerpt store for this project.
    context: Entity<RelatedExcerptStore>,
    // License watchers per worktree; consulted to decide whether edits are open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // Bounded to `USER_ACTION_HISTORY_SIZE` entries by `record_user_action`.
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    // Created lazily by `EditPredictionStore::start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
 321
 322impl ProjectState {
 323    fn record_user_action(&mut self, action: UserActionRecord) {
 324        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 325            self.user_actions.pop_front();
 326        }
 327        self.user_actions.push_back(action);
 328    }
 329
 330    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 331        self.events
 332            .iter()
 333            .cloned()
 334            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 335                let (one, two) = event.split_by_pause();
 336                let one = one.finalize(&self.license_detection_watchers, cx);
 337                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 338                one.into_iter().chain(two)
 339            }))
 340            .collect()
 341    }
 342
 343    fn cancel_pending_prediction(
 344        &mut self,
 345        pending_prediction: PendingPrediction,
 346        cx: &mut Context<EditPredictionStore>,
 347    ) {
 348        self.cancelled_predictions.insert(pending_prediction.id);
 349
 350        if pending_prediction.drop_on_cancel {
 351            drop(pending_prediction.task);
 352        } else {
 353            cx.spawn(async move |this, cx| {
 354                let Some(prediction_id) = pending_prediction.task.await else {
 355                    return;
 356                };
 357
 358                this.update(cx, |this, cx| {
 359                    this.reject_prediction(
 360                        prediction_id,
 361                        EditPredictionRejectReason::Canceled,
 362                        false,
 363                        cx,
 364                    );
 365                })
 366                .ok();
 367            })
 368            .detach()
 369        }
 370    }
 371
 372    fn active_buffer(
 373        &self,
 374        project: &Entity<Project>,
 375        cx: &App,
 376    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 377        let project = project.read(cx);
 378        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 379        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 380        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 381        Some((active_buffer, registered_buffer.last_position))
 382    }
 383}
 384
/// The prediction currently offered for a project, with request/display metadata.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    // What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // NOTE(review): presumably set once the prediction has been displayed — confirm.
    pub was_shown: bool,
    // How the prediction was displayed, if it was.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
 392
 393impl CurrentEditPrediction {
 394    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 395        let Some(new_edits) = self
 396            .prediction
 397            .interpolate(&self.prediction.buffer.read(cx))
 398        else {
 399            return false;
 400        };
 401
 402        if self.prediction.buffer != old_prediction.prediction.buffer {
 403            return true;
 404        }
 405
 406        let Some(old_edits) = old_prediction
 407            .prediction
 408            .interpolate(&old_prediction.prediction.buffer.read(cx))
 409        else {
 410            return true;
 411        };
 412
 413        let requested_by_buffer_id = self.requested_by.buffer_id();
 414
 415        // This reduces the occurrence of UI thrash from replacing edits
 416        //
 417        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 418        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 419            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 420            && old_edits.len() == 1
 421            && new_edits.len() == 1
 422        {
 423            let (old_range, old_text) = &old_edits[0];
 424            let (new_range, new_text) = &new_edits[0];
 425            new_range == old_range && new_text.starts_with(old_text.as_ref())
 426        } else {
 427            true
 428        }
 429    }
 430}
 431
 432#[derive(Debug, Clone)]
 433enum PredictionRequestedBy {
 434    DiagnosticsUpdate,
 435    Buffer(EntityId),
 436}
 437
 438impl PredictionRequestedBy {
 439    pub fn buffer_id(&self) -> Option<EntityId> {
 440        match self {
 441            PredictionRequestedBy::DiagnosticsUpdate => None,
 442            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 443        }
 444    }
 445}
 446
/// Row span used when looking for diagnostics around a position.
// NOTE(review): exact use (radius vs. total span) is not visible here — confirm.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;
 448
/// Where to search for diagnostics.
// NOTE(review): presumably `Local` = near the current position and `Global` =
// project-wide — confirm at use sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
 454
/// An in-flight prediction request tracked by a project.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `cancelled_predictions` when this request is cancelled.
    id: usize,
    /// Resolves to the produced prediction's id, or `None` if none was produced.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 463
 464/// A prediction from the perspective of a buffer.
 465#[derive(Debug)]
 466enum BufferEditPrediction<'a> {
 467    Local { prediction: &'a EditPrediction },
 468    Jump { prediction: &'a EditPrediction },
 469}
 470
 471#[cfg(test)]
 472impl std::ops::Deref for BufferEditPrediction<'_> {
 473    type Target = EditPrediction;
 474
 475    fn deref(&self) -> &Self::Target {
 476        match self {
 477            BufferEditPrediction::Local { prediction } => prediction,
 478            BufferEditPrediction::Jump { prediction } => prediction,
 479        }
 480    }
 481}
 482
/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    // Last recorded position in this buffer; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
 489
/// An edit event still being accumulated, before it becomes a [`StoredEvent`].
#[derive(Clone)]
struct LastEvent {
    // Buffer contents when the event began.
    old_snapshot: TextBufferSnapshot,
    // Buffer contents after the most recent edit in the event.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    // Cleared by `split_by_pause`; `finalize` derives its range from the diff instead.
    edit_range: Option<Range<Anchor>>,
    // Whether the edits came from an accepted prediction rather than manual typing.
    predicted: bool,
    // Snapshot at the last editing pause; `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 501
 502impl LastEvent {
 503    pub fn finalize(
 504        &self,
 505        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 506        cx: &App,
 507    ) -> Option<StoredEvent> {
 508        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 509        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 510
 511        let in_open_source_repo =
 512            [self.new_file.as_ref(), self.old_file.as_ref()]
 513                .iter()
 514                .all(|file| {
 515                    file.is_some_and(|file| {
 516                        license_detection_watchers
 517                            .get(&file.worktree_id(cx))
 518                            .is_some_and(|watcher| watcher.is_project_open_source())
 519                    })
 520                });
 521
 522        let (diff, edit_range) =
 523            compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 524
 525        if path == old_path && diff.is_empty() {
 526            None
 527        } else {
 528            Some(StoredEvent {
 529                event: Arc::new(zeta_prompt::Event::BufferChange {
 530                    old_path,
 531                    path,
 532                    diff,
 533                    in_open_source_repo,
 534                    predicted: self.predicted,
 535                }),
 536                edit_range: self.new_snapshot.anchor_before(edit_range.start)
 537                    ..self.new_snapshot.anchor_before(edit_range.end),
 538                old_snapshot: self.old_snapshot.clone(),
 539            })
 540        }
 541    }
 542
 543    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 544        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 545            return (self.clone(), None);
 546        };
 547
 548        let before = LastEvent {
 549            old_snapshot: self.old_snapshot.clone(),
 550            new_snapshot: boundary_snapshot.clone(),
 551            old_file: self.old_file.clone(),
 552            new_file: self.new_file.clone(),
 553            edit_range: None,
 554            predicted: self.predicted,
 555            snapshot_after_last_editing_pause: None,
 556            last_edit_time: self.last_edit_time,
 557        };
 558
 559        let after = LastEvent {
 560            old_snapshot: boundary_snapshot.clone(),
 561            new_snapshot: self.new_snapshot.clone(),
 562            old_file: self.old_file.clone(),
 563            new_file: self.new_file.clone(),
 564            edit_range: None,
 565            predicted: self.predicted,
 566            snapshot_after_last_editing_pause: None,
 567            last_edit_time: self.last_edit_time,
 568        };
 569
 570        (before, Some(after))
 571    }
 572}
 573
 574pub(crate) fn compute_diff_between_snapshots(
 575    old_snapshot: &TextBufferSnapshot,
 576    new_snapshot: &TextBufferSnapshot,
 577) -> Option<(String, Range<Point>)> {
 578    let edits: Vec<Edit<usize>> = new_snapshot
 579        .edits_since::<usize>(&old_snapshot.version)
 580        .collect();
 581
 582    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 583
 584    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 585    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 586    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 587    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 588
 589    const CONTEXT_LINES: u32 = 3;
 590
 591    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 592    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 593    let old_context_end_row =
 594        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 595    let new_context_end_row =
 596        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 597
 598    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 599    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 600    let old_end_line_offset = old_snapshot
 601        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 602    let new_end_line_offset = new_snapshot
 603        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 604    let old_edit_range = old_start_line_offset..old_end_line_offset;
 605    let new_edit_range = new_start_line_offset..new_end_line_offset;
 606
 607    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 608    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 609
 610    let diff = language::unified_diff_with_offsets(
 611        &old_region_text,
 612        &new_region_text,
 613        old_context_start_row,
 614        new_context_start_row,
 615    );
 616
 617    Some((diff, new_start_point..new_end_point))
 618}
 619
 620fn buffer_path_with_id_fallback(
 621    file: Option<&Arc<dyn File>>,
 622    snapshot: &TextBufferSnapshot,
 623    cx: &App,
 624) -> Arc<Path> {
 625    if let Some(file) = file {
 626        file.full_path(cx).into()
 627    } else {
 628        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 629    }
 630}
 631
 632impl EditPredictionStore {
 633    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 634        cx.try_global::<EditPredictionStoreGlobal>()
 635            .map(|global| global.0.clone())
 636    }
 637
 638    pub fn global(
 639        client: &Arc<Client>,
 640        user_store: &Entity<UserStore>,
 641        cx: &mut App,
 642    ) -> Entity<Self> {
 643        cx.try_global::<EditPredictionStoreGlobal>()
 644            .map(|global| global.0.clone())
 645            .unwrap_or_else(|| {
 646                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 647                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 648                ep_store
 649            })
 650    }
 651
 652    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 653        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 654        let data_collection_choice = Self::load_data_collection_choice();
 655
 656        let llm_token = LlmApiToken::default();
 657
 658        let (reject_tx, reject_rx) = mpsc::unbounded();
 659        cx.background_spawn({
 660            let client = client.clone();
 661            let llm_token = llm_token.clone();
 662            let app_version = AppVersion::global(cx);
 663            let background_executor = cx.background_executor().clone();
 664            async move {
 665                Self::handle_rejected_predictions(
 666                    reject_rx,
 667                    client,
 668                    llm_token,
 669                    app_version,
 670                    background_executor,
 671                )
 672                .await
 673            }
 674        })
 675        .detach();
 676
 677        let this = Self {
 678            projects: HashMap::default(),
 679            client,
 680            user_store,
 681            llm_token,
 682            _llm_token_subscription: cx.subscribe(
 683                &refresh_llm_token_listener,
 684                |this, _listener, _event, cx| {
 685                    let client = this.client.clone();
 686                    let llm_token = this.llm_token.clone();
 687                    cx.spawn(async move |_this, _cx| {
 688                        llm_token.refresh(&client).await?;
 689                        anyhow::Ok(())
 690                    })
 691                    .detach_and_log_err(cx);
 692                },
 693            ),
 694            update_required: false,
 695            edit_prediction_model: EditPredictionModel::Zeta2,
 696            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 697            sweep_ai: SweepAi::new(cx),
 698            mercury: Mercury::new(cx),
 699
 700            data_collection_choice,
 701            reject_predictions_tx: reject_tx,
 702            rated_predictions: Default::default(),
 703            shown_predictions: Default::default(),
 704        };
 705
 706        this
 707    }
 708
 709    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 710        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 711        let format = ZetaFormat::parse(&version_str).ok()?;
 712        let model_id = env::var("ZED_ZETA_MODEL").ok();
 713        Some(Zeta2RawConfig { model_id, format })
 714    }
 715
    /// Selects the backend used for subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
 719
    /// Installs a raw Zeta2 configuration, replacing any previous one.
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }
 723
    /// The current raw Zeta2 configuration, if any.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }
 727
 728    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 729        use ui::IconName;
 730        match self.edit_prediction_model {
 731            EditPredictionModel::Sweep => {
 732                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 733                    .with_disabled(IconName::SweepAiDisabled)
 734                    .with_up(IconName::SweepAiUp)
 735                    .with_down(IconName::SweepAiDown)
 736                    .with_error(IconName::SweepAiError)
 737            }
 738            EditPredictionModel::Mercury => {
 739                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 740            }
 741            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
 742                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 743                    .with_disabled(IconName::ZedPredictDisabled)
 744                    .with_up(IconName::ZedPredictUp)
 745                    .with_down(IconName::ZedPredictDown)
 746                    .with_error(IconName::ZedPredictError)
 747            }
 748            EditPredictionModel::Fim { .. } => {
 749                let settings = &all_language_settings(None, cx).edit_predictions;
 750                match settings.provider {
 751                    EditPredictionProvider::Ollama => {
 752                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 753                    }
 754                    _ => {
 755                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 756                    }
 757                }
 758            }
 759        }
 760    }
 761
    /// Whether a Sweep API key is configured.
    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
        self.sweep_ai.api_token.read(cx).has_key()
    }
 765
    /// Whether a Mercury API key is configured.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }
 769
 770    pub fn clear_history(&mut self) {
 771        for project_state in self.projects.values_mut() {
 772            project_state.events.clear();
 773            project_state.last_event.take();
 774        }
 775    }
 776
 777    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 778        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 779            project_state.events.clear();
 780            project_state.last_event.take();
 781        }
 782    }
 783
 784    pub fn edit_history_for_project(
 785        &self,
 786        project: &Entity<Project>,
 787        cx: &App,
 788    ) -> Vec<StoredEvent> {
 789        self.projects
 790            .get(&project.entity_id())
 791            .map(|project_state| project_state.events(cx))
 792            .unwrap_or_default()
 793    }
 794
 795    pub fn context_for_project<'a>(
 796        &'a self,
 797        project: &Entity<Project>,
 798        cx: &'a mut App,
 799    ) -> Vec<RelatedFile> {
 800        self.projects
 801            .get(&project.entity_id())
 802            .map(|project_state| {
 803                project_state.context.update(cx, |context, cx| {
 804                    context
 805                        .related_files_with_buffers(cx)
 806                        .map(|(mut related_file, buffer)| {
 807                            related_file.in_open_source_repo = buffer
 808                                .read(cx)
 809                                .file()
 810                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
 811                            related_file
 812                        })
 813                        .collect()
 814                })
 815            })
 816            .unwrap_or_default()
 817    }
 818
 819    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 820        self.projects
 821            .get(&project.entity_id())
 822            .and_then(|project| project.copilot.clone())
 823    }
 824
 825    pub fn start_copilot_for_project(
 826        &mut self,
 827        project: &Entity<Project>,
 828        cx: &mut Context<Self>,
 829    ) -> Option<Entity<Copilot>> {
 830        if DisableAiSettings::get(None, cx).disable_ai {
 831            return None;
 832        }
 833        let state = self.get_or_init_project(project, cx);
 834
 835        if state.copilot.is_some() {
 836            return state.copilot.clone();
 837        }
 838        let _project = project.clone();
 839        let project = project.read(cx);
 840
 841        let node = project.node_runtime().cloned();
 842        if let Some(node) = node {
 843            let next_id = project.languages().next_language_server_id();
 844            let fs = project.fs().clone();
 845
 846            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 847            state.copilot = Some(copilot.clone());
 848            Some(copilot)
 849        } else {
 850            None
 851        }
 852    }
 853
 854    pub fn context_for_project_with_buffers<'a>(
 855        &'a self,
 856        project: &Entity<Project>,
 857        cx: &'a mut App,
 858    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 859        self.projects
 860            .get(&project.entity_id())
 861            .map(|project| {
 862                project.context.update(cx, |context, cx| {
 863                    context.related_files_with_buffers(cx).collect()
 864                })
 865            })
 866            .unwrap_or_default()
 867    }
 868
 869    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 870        if matches!(
 871            self.edit_prediction_model,
 872            EditPredictionModel::Zeta2 | EditPredictionModel::Zeta1
 873        ) {
 874            self.user_store.read(cx).edit_prediction_usage()
 875        } else {
 876            None
 877        }
 878    }
 879
 880    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 881        self.get_or_init_project(project, cx);
 882    }
 883
 884    pub fn register_buffer(
 885        &mut self,
 886        buffer: &Entity<Buffer>,
 887        project: &Entity<Project>,
 888        cx: &mut Context<Self>,
 889    ) {
 890        let project_state = self.get_or_init_project(project, cx);
 891        Self::register_buffer_impl(project_state, buffer, project, cx);
 892    }
 893
 894    fn get_or_init_project(
 895        &mut self,
 896        project: &Entity<Project>,
 897        cx: &mut Context<Self>,
 898    ) -> &mut ProjectState {
 899        let entity_id = project.entity_id();
 900        self.projects
 901            .entry(entity_id)
 902            .or_insert_with(|| ProjectState {
 903                context: {
 904                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 905                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 906                        this.handle_excerpt_store_event(entity_id, event);
 907                    })
 908                    .detach();
 909                    related_excerpt_store
 910                },
 911                events: VecDeque::new(),
 912                last_event: None,
 913                recent_paths: VecDeque::new(),
 914                debug_tx: None,
 915                registered_buffers: HashMap::default(),
 916                current_prediction: None,
 917                cancelled_predictions: HashSet::default(),
 918                pending_predictions: ArrayVec::new(),
 919                next_pending_prediction_id: 0,
 920                last_edit_prediction_refresh: None,
 921                last_jump_prediction_refresh: None,
 922                license_detection_watchers: HashMap::default(),
 923                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
 924                _subscriptions: [
 925                    cx.subscribe(&project, Self::handle_project_event),
 926                    cx.observe_release(&project, move |this, _, cx| {
 927                        this.projects.remove(&entity_id);
 928                        cx.notify();
 929                    }),
 930                ],
 931                copilot: None,
 932            })
 933    }
 934
 935    pub fn remove_project(&mut self, project: &Entity<Project>) {
 936        self.projects.remove(&project.entity_id());
 937    }
 938
 939    fn handle_excerpt_store_event(
 940        &mut self,
 941        project_entity_id: EntityId,
 942        event: &RelatedExcerptStoreEvent,
 943    ) {
 944        if let Some(project_state) = self.projects.get(&project_entity_id) {
 945            if let Some(debug_tx) = project_state.debug_tx.clone() {
 946                match event {
 947                    RelatedExcerptStoreEvent::StartedRefresh => {
 948                        debug_tx
 949                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 950                                ContextRetrievalStartedDebugEvent {
 951                                    project_entity_id: project_entity_id,
 952                                    timestamp: Instant::now(),
 953                                    search_prompt: String::new(),
 954                                },
 955                            ))
 956                            .ok();
 957                    }
 958                    RelatedExcerptStoreEvent::FinishedRefresh {
 959                        cache_hit_count,
 960                        cache_miss_count,
 961                        mean_definition_latency,
 962                        max_definition_latency,
 963                    } => {
 964                        debug_tx
 965                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 966                                ContextRetrievalFinishedDebugEvent {
 967                                    project_entity_id: project_entity_id,
 968                                    timestamp: Instant::now(),
 969                                    metadata: vec![
 970                                        (
 971                                            "Cache Hits",
 972                                            format!(
 973                                                "{}/{}",
 974                                                cache_hit_count,
 975                                                cache_hit_count + cache_miss_count
 976                                            )
 977                                            .into(),
 978                                        ),
 979                                        (
 980                                            "Max LSP Time",
 981                                            format!("{} ms", max_definition_latency.as_millis())
 982                                                .into(),
 983                                        ),
 984                                        (
 985                                            "Mean LSP Time",
 986                                            format!("{} ms", mean_definition_latency.as_millis())
 987                                                .into(),
 988                                        ),
 989                                    ],
 990                                },
 991                            ))
 992                            .ok();
 993                    }
 994                }
 995            }
 996        }
 997    }
 998
 999    pub fn debug_info(
1000        &mut self,
1001        project: &Entity<Project>,
1002        cx: &mut Context<Self>,
1003    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1004        let project_state = self.get_or_init_project(project, cx);
1005        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1006        project_state.debug_tx = Some(debug_watch_tx);
1007        debug_watch_rx
1008    }
1009
1010    fn handle_project_event(
1011        &mut self,
1012        project: Entity<Project>,
1013        event: &project::Event,
1014        cx: &mut Context<Self>,
1015    ) {
1016        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1017            return;
1018        }
1019        // TODO [zeta2] init with recent paths
1020        match event {
1021            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1022                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1023                    return;
1024                };
1025                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1026                if let Some(path) = path {
1027                    if let Some(ix) = project_state
1028                        .recent_paths
1029                        .iter()
1030                        .position(|probe| probe == &path)
1031                    {
1032                        project_state.recent_paths.remove(ix);
1033                    }
1034                    project_state.recent_paths.push_front(path);
1035                }
1036            }
1037            project::Event::DiagnosticsUpdated { .. } => {
1038                if cx.has_flag::<Zeta2FeatureFlag>() {
1039                    self.refresh_prediction_from_diagnostics(
1040                        project,
1041                        DiagnosticSearchScope::Global,
1042                        cx,
1043                    );
1044                }
1045            }
1046            _ => (),
1047        }
1048    }
1049
    /// Ensures `buffer` is tracked in `project_state` and returns its `RegisteredBuffer` entry.
    ///
    /// If the buffer has a file whose worktree belongs to `project`, a `LicenseDetectionWatcher`
    /// is created for that worktree (once per worktree). It is removed again when the worktree
    /// entity is released.
    ///
    /// On first registration, the buffer's current text snapshot and file are recorded and two
    /// subscriptions are installed:
    /// - one calls `report_changes_for_buffer` on every `BufferEvent::Edited`;
    /// - one drops the entry when the buffer entity is released.
    ///
    /// Already-registered buffers are returned unchanged.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                // Only the first buffer seen in a worktree creates its watcher.
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Holds only a downgraded (weak) project handle; edits are ignored once
                        // it can no longer be upgraded. `false`: not a predicted change.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1115
    /// Diffs `buffer` against the snapshot last recorded for it and folds the change into the
    /// project's edit history.
    ///
    /// Returns early when the buffer version is unchanged or no edits are found. Otherwise it
    /// records a `UserActionRecord` for the change. It then does one of two things:
    /// - Coalesces the change into `last_event`. This happens when the change continues from
    ///   `last_event`'s new snapshot of the same buffer, has the same `is_predicted` source, and
    ///   lies within `CHANGE_GROUPING_LINE_SPAN` lines of the previous edit range.
    /// - Otherwise, finalizes the previous `last_event` into `events` and starts a new
    ///   `last_event`.
    ///
    /// `is_predicted` is passed as `true` by `accept_current_prediction` and as `false` by the
    /// buffer-edit subscription installed in `register_buffer_impl`.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the snapshot recorded at the previous report.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        // Tally the sizes of all edits since the old snapshot. Build one anchor range running
        // from the first edit's start to the last edit's end.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // "Char" means every edit inserted/deleted exactly one unit of the `usize` offset
        // dimension. NOTE(review): these are presumably byte lengths, so typing a single
        // multi-byte character is classified as a selection edit; confirm this is intended.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        // Attribute the action to the end of the last edit, measured in the new snapshot.
        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // After a pause of at least LAST_CHANGE_GROUPING_TIME since the previous edit,
                // remember the state reached before this new burst of edits.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Evicts once `events` holds EVENT_COUNT_MAX - 1 entries, so after the push it
                // holds at most EVENT_COUNT_MAX - 1. NOTE(review): presumably this reserves one
                // slot for `last_event`; confirm.
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1243
1244    fn prediction_at(
1245        &mut self,
1246        buffer: &Entity<Buffer>,
1247        position: Option<language::Anchor>,
1248        project: &Entity<Project>,
1249        cx: &App,
1250    ) -> Option<BufferEditPrediction<'_>> {
1251        let project_state = self.projects.get_mut(&project.entity_id())?;
1252        if let Some(position) = position
1253            && let Some(buffer) = project_state
1254                .registered_buffers
1255                .get_mut(&buffer.entity_id())
1256        {
1257            buffer.last_position = Some(position);
1258        }
1259
1260        let CurrentEditPrediction {
1261            requested_by,
1262            prediction,
1263            ..
1264        } = project_state.current_prediction.as_ref()?;
1265
1266        if prediction.targets_buffer(buffer.read(cx)) {
1267            Some(BufferEditPrediction::Local { prediction })
1268        } else {
1269            let show_jump = match requested_by {
1270                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1271                    requested_by_buffer_id == &buffer.entity_id()
1272                }
1273                PredictionRequestedBy::DiagnosticsUpdate => true,
1274            };
1275
1276            if show_jump {
1277                Some(BufferEditPrediction::Jump { prediction })
1278            } else {
1279                None
1280            }
1281        }
1282    }
1283
    /// Accepts the project's current prediction, if there is one.
    ///
    /// The steps are:
    /// 1. Take the prediction out of `current_prediction`.
    /// 2. Record the resulting change to its buffer with `is_predicted = true`.
    /// 3. Cancel every pending prediction request for the project via
    ///    `cancel_pending_prediction`.
    /// 4. Report the acceptance to the active model's backend.
    ///
    /// Zeta acceptances are only reported when the provider is neither Ollama nor an
    /// OpenAI-compatible API. The FIM model reports nothing.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(current_prediction) = self
            .projects
            .get_mut(&project.entity_id())
            .and_then(|project_state| project_state.current_prediction.take())
        else {
            return;
        };

        self.report_changes_for_buffer(&current_prediction.prediction.buffer, project, true, cx);

        // can't hold &mut project_state ref across the report_changes_for_buffer call,
        // so look the project state up again
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        match self.edit_prediction_model {
            EditPredictionModel::Sweep => {
                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
            }
            EditPredictionModel::Mercury => {
                mercury::edit_prediction_accepted(
                    current_prediction.prediction.id,
                    self.client.http_client(),
                    cx,
                );
            }
            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
                // Locally configured endpoints are not reported to Zed's cloud.
                let is_cloud = !matches!(
                    all_language_settings(None, cx).edit_predictions.provider,
                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
                );
                if is_cloud {
                    zeta::edit_prediction_accepted(self, current_prediction, cx)
                }
            }
            EditPredictionModel::Fim { .. } => {}
        }
    }
1327
    /// Background loop that batches prediction rejections and posts them to the
    /// `/predict_edits/reject` URL built by `build_zed_llm_url`.
    ///
    /// Batching: after a rejection arrives, and while the batch is smaller than half of
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`, the loop waits for whichever comes first:
    /// - another rejection, which is added to the batch;
    /// - the `REJECT_REQUEST_DEBOUNCE` timer, which triggers a flush.
    ///
    /// A closed channel also triggers a flush.
    ///
    /// Sending: each request carries at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the
    /// newest batched rejections. They are removed from `batched` only when the request
    /// succeeds. Otherwise the error is logged and they are retried when the next rejection
    /// arrives.
    ///
    /// NOTE(review): rejections that keep failing accumulate in `batched` without a cap, and any
    /// still batched when the channel closes are dropped. Confirm both are acceptable.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            // Debounce small batches: if another rejection shows up before the timer fires,
            // loop around to collect it instead of sending now.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Send the newest `flush_count` items; older leftovers wait for a later flush.
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
                true,
            )
            .await;

            // Only forget the sent items once the server accepted them.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1386
1387    fn reject_current_prediction(
1388        &mut self,
1389        reason: EditPredictionRejectReason,
1390        project: &Entity<Project>,
1391        cx: &App,
1392    ) {
1393        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1394            project_state.pending_predictions.clear();
1395            if let Some(prediction) = project_state.current_prediction.take() {
1396                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
1397            }
1398        };
1399    }
1400
1401    fn did_show_current_prediction(
1402        &mut self,
1403        project: &Entity<Project>,
1404        display_type: edit_prediction_types::SuggestionDisplayType,
1405        cx: &mut Context<Self>,
1406    ) {
1407        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1408            return;
1409        };
1410
1411        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1412            return;
1413        };
1414
1415        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1416        let previous_shown_with = current_prediction.shown_with;
1417
1418        if previous_shown_with.is_none() || !is_jump {
1419            current_prediction.shown_with = Some(display_type);
1420        }
1421
1422        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1423
1424        if is_first_non_jump_show {
1425            current_prediction.was_shown = true;
1426        }
1427
1428        let display_type_changed = previous_shown_with != Some(display_type);
1429
1430        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1431            sweep_ai::edit_prediction_shown(
1432                &self.sweep_ai,
1433                self.client.clone(),
1434                &current_prediction.prediction,
1435                display_type,
1436                cx,
1437            );
1438        }
1439
1440        if is_first_non_jump_show {
1441            self.shown_predictions
1442                .push_front(current_prediction.prediction.clone());
1443            if self.shown_predictions.len() > 50 {
1444                let completion = self.shown_predictions.pop_back().unwrap();
1445                self.rated_predictions.remove(&completion.id);
1446            }
1447        }
1448    }
1449
1450    fn reject_prediction(
1451        &mut self,
1452        prediction_id: EditPredictionId,
1453        reason: EditPredictionRejectReason,
1454        was_shown: bool,
1455        cx: &App,
1456    ) {
1457        match self.edit_prediction_model {
1458            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1459                let is_cloud = !matches!(
1460                    all_language_settings(None, cx).edit_predictions.provider,
1461                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1462                );
1463                if is_cloud {
1464                    self.reject_predictions_tx
1465                        .unbounded_send(EditPredictionRejection {
1466                            request_id: prediction_id.to_string(),
1467                            reason,
1468                            was_shown,
1469                        })
1470                        .log_err();
1471                }
1472            }
1473            EditPredictionModel::Mercury => {
1474                mercury::edit_prediction_rejected(
1475                    prediction_id,
1476                    was_shown,
1477                    reason,
1478                    self.client.http_client(),
1479                    cx,
1480                );
1481            }
1482            EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1483        }
1484    }
1485
1486    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1487        self.projects
1488            .get(&project.entity_id())
1489            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1490    }
1491
1492    pub fn refresh_prediction_from_buffer(
1493        &mut self,
1494        project: Entity<Project>,
1495        buffer: Entity<Buffer>,
1496        position: language::Anchor,
1497        cx: &mut Context<Self>,
1498    ) {
1499        self.queue_prediction_refresh(
1500            project.clone(),
1501            PredictEditsRequestTrigger::Other,
1502            buffer.entity_id(),
1503            cx,
1504            move |this, cx| {
1505                let Some(request_task) = this
1506                    .update(cx, |this, cx| {
1507                        this.request_prediction(
1508                            &project,
1509                            &buffer,
1510                            position,
1511                            PredictEditsRequestTrigger::Other,
1512                            cx,
1513                        )
1514                    })
1515                    .log_err()
1516                else {
1517                    return Task::ready(anyhow::Ok(None));
1518                };
1519
1520                cx.spawn(async move |_cx| {
1521                    request_task.await.map(|prediction_result| {
1522                        prediction_result.map(|prediction_result| {
1523                            (
1524                                prediction_result,
1525                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1526                            )
1527                        })
1528                    })
1529                })
1530            },
1531        )
1532    }
1533
    /// Queues a diagnostics-triggered prediction request at the next diagnostic location.
    ///
    /// Nothing is queued when any of these hold:
    /// - `is_ep_store_provider` rejects the configured provider;
    /// - the project is not registered;
    /// - a prediction is already current (predictions from buffer edits are preferred).
    ///
    /// The queued job then:
    /// 1. Resolves the project's active buffer and cursor, bailing if predictions are disabled
    ///    there.
    /// 2. Asks `next_diagnostic_location` for a target, searching within
    ///    `DIAGNOSTIC_LINES_RANGE` rows of the cursor for `Local` scope, or with a default
    ///    (empty) range for `Global`.
    /// 3. Requests a prediction at that location.
    ///
    /// If a current prediction appeared while the request was in flight, the result is turned
    /// into a `CurrentPreferred` rejection.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        // Falls back to `Point::default()` when no cursor position is known.
                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        // NOTE(review): presumably `next_diagnostic_location` treats this empty
                        // range as "no row restriction"; confirm.
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // Re-check: a buffer-driven prediction may have become current meanwhile.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
1642
1643    fn predictions_enabled_at(
1644        snapshot: &BufferSnapshot,
1645        position: Option<language::Anchor>,
1646        cx: &App,
1647    ) -> bool {
1648        let file = snapshot.file();
1649        let all_settings = all_language_settings(file, cx);
1650        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1651            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1652        {
1653            return false;
1654        }
1655
1656        if let Some(last_position) = position {
1657            let settings = snapshot.settings_at(last_position, cx);
1658
1659            if !settings.edit_predictions_disabled_in.is_empty()
1660                && let Some(scope) = snapshot.language_scope_at(last_position)
1661                && let Some(scope_name) = scope.override_name()
1662                && settings
1663                    .edit_predictions_disabled_in
1664                    .iter()
1665                    .any(|s| s == scope_name)
1666            {
1667                return false;
1668            }
1669        }
1670
1671        true
1672    }
1673
    /// Throttle window for prediction refreshes. `queue_prediction_refresh`
    /// delays a new refresh until this long after the previous refresh of the
    /// same kind, when both were triggered by the same throttle entity.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1675}
1676
1677fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1678    match provider {
1679        EditPredictionProvider::Zed
1680        | EditPredictionProvider::Sweep
1681        | EditPredictionProvider::Mercury
1682        | EditPredictionProvider::Ollama
1683        | EditPredictionProvider::OpenAiCompatibleApi
1684        | EditPredictionProvider::Experimental(_) => true,
1685        EditPredictionProvider::None
1686        | EditPredictionProvider::Copilot
1687        | EditPredictionProvider::Supermaven
1688        | EditPredictionProvider::Codestral => false,
1689    }
1690}
1691
1692impl EditPredictionStore {
    /// Queues a prediction refresh for `project` whose prediction is produced
    /// by `do_refresh`.
    ///
    /// The refresh runs in a spawned task that:
    /// 1. waits out the rest of `THROTTLE_TIMEOUT` if the previous refresh of
    ///    the same kind was triggered by the same `throttle_entity`.
    ///    Diagnostics-triggered refreshes and all other triggers are throttled
    ///    separately;
    /// 2. returns without calling `do_refresh` if the refresh was cancelled in
    ///    the meantime, or if the store could no longer be updated;
    /// 3. otherwise runs `do_refresh` and, unless the refresh was cancelled
    ///    while it ran, reconciles the result with the project's current
    ///    prediction, rejecting whichever prediction is not kept;
    /// 4. removes itself from the pending list and cancels every pending
    ///    refresh positioned before it.
    ///
    /// At most `max_pending_predictions` refreshes are pending at once (1 for
    /// Ollama, 2 for the other store-backed providers). When the list is full,
    /// the last pending refresh is popped, cancelled, and replaced by this one.
    /// For providers not backed by this store, an error is logged and nothing
    /// is queued.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        request_trigger: PredictEditsRequestTrigger,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        // Diagnostics-triggered (jump) refreshes keep their own throttle state,
        // separate from the state shared by all other triggers.
        fn select_throttle(
            project_state: &mut ProjectState,
            request_trigger: PredictEditsRequestTrigger,
        ) -> &mut Option<(EntityId, Instant)> {
            match request_trigger {
                PredictEditsRequestTrigger::Diagnostics => {
                    &mut project_state.last_jump_prediction_refresh
                }
                _ => &mut project_state.last_edit_prediction_refresh,
            }
        }

        // Per-provider policy: whether predictions need acceptance tracking, and
        // how many refreshes may be pending at once.
        let (needs_acceptance_tracking, max_pending_predictions) =
            match all_language_settings(None, cx).edit_predictions.provider {
                EditPredictionProvider::Zed
                | EditPredictionProvider::Sweep
                | EditPredictionProvider::Mercury
                | EditPredictionProvider::Experimental(_) => (true, 2),
                EditPredictionProvider::Ollama => (false, 1),
                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
                EditPredictionProvider::None
                | EditPredictionProvider::Copilot
                | EditPredictionProvider::Supermaven
                | EditPredictionProvider::Codestral => {
                    log::error!("queue_prediction_refresh called with non-store provider");
                    return;
                }
            };

        // NOTE(review): the exact effect of `drop_on_cancel` lives in
        // `ProjectState::cancel_pending_prediction` (not shown here); presumably
        // it drops the task outright for providers without acceptance tracking.
        let drop_on_cancel = !needs_acceptance_tracking;
        let throttle_timeout = Self::THROTTLE_TIMEOUT;
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        // Snapshot of the throttle slot taken before spawning; the task compares
        // against this value, not against later updates.
        let last_request = *select_throttle(project_state, request_trigger);

        let task = cx.spawn(async move |this, cx| {
            // Throttle only when the same entity triggered the previous refresh
            // of this kind and the window has not yet elapsed.
            if let Some(timeout) = last_request.and_then(|(last_entity, last_timestamp)| {
                if throttle_entity != last_entity {
                    return None;
                }
                (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
            }) {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            // `is_cancelled` also stays true when `this.update` fails, so a
            // dropped store is treated like a cancellation.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                let was_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);
                if !was_cancelled {
                    // Record this refresh as the latest one for throttling.
                    let new_refresh = (throttle_entity, Instant::now());
                    *select_throttle(project_state, request_trigger) = Some(new_refresh);
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            // Errors from `do_refresh` are logged and treated as "no prediction".
            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                // A refresh cancelled while its request was in flight is discarded
                // here; no reject call is made for its result.
                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                                shown_with: None,
                            };

                            // Keep whichever of the current and new predictions
                            // `should_replace_prediction` prefers; reject the other.
                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                        cx,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                        cx,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false, cx);
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-fetch: the reject calls above needed `this` mutably.
                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Remove this refresh from the pending list and cancel every entry
                // positioned before it. If it is no longer in the list (it was
                // evicted), nothing is cancelled.
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Enforce the pending cap: when full, pop the last pending refresh,
        // cancel it, and put this one in its place.
        if project_state.pending_predictions.len() < max_pending_predictions {
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
        } else {
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1871
1872    pub fn request_prediction(
1873        &mut self,
1874        project: &Entity<Project>,
1875        active_buffer: &Entity<Buffer>,
1876        position: language::Anchor,
1877        trigger: PredictEditsRequestTrigger,
1878        cx: &mut Context<Self>,
1879    ) -> Task<Result<Option<EditPredictionResult>>> {
1880        self.request_prediction_internal(
1881            project.clone(),
1882            active_buffer.clone(),
1883            position,
1884            trigger,
1885            cx.has_flag::<Zeta2FeatureFlag>(),
1886            cx,
1887        )
1888    }
1889
1890    fn request_prediction_internal(
1891        &mut self,
1892        project: Entity<Project>,
1893        active_buffer: Entity<Buffer>,
1894        position: language::Anchor,
1895        trigger: PredictEditsRequestTrigger,
1896        allow_jump: bool,
1897        cx: &mut Context<Self>,
1898    ) -> Task<Result<Option<EditPredictionResult>>> {
1899        self.get_or_init_project(&project, cx);
1900        let project_state = self.projects.get(&project.entity_id()).unwrap();
1901        let stored_events = project_state.events(cx);
1902        let has_events = !stored_events.is_empty();
1903        let events: Vec<Arc<zeta_prompt::Event>> =
1904            stored_events.into_iter().map(|e| e.event).collect();
1905        let debug_tx = project_state.debug_tx.clone();
1906
1907        let snapshot = active_buffer.read(cx).snapshot();
1908        let cursor_point = position.to_point(&snapshot);
1909        let current_offset = position.to_offset(&snapshot);
1910
1911        let mut user_actions: Vec<UserActionRecord> =
1912            project_state.user_actions.iter().cloned().collect();
1913
1914        if let Some(last_action) = user_actions.last() {
1915            if last_action.buffer_id == active_buffer.entity_id()
1916                && current_offset != last_action.offset
1917            {
1918                let timestamp_epoch_ms = SystemTime::now()
1919                    .duration_since(UNIX_EPOCH)
1920                    .map(|d| d.as_millis() as u64)
1921                    .unwrap_or(0);
1922                user_actions.push(UserActionRecord {
1923                    action_type: UserActionType::CursorMovement,
1924                    buffer_id: active_buffer.entity_id(),
1925                    line_number: cursor_point.row,
1926                    offset: current_offset,
1927                    timestamp_epoch_ms,
1928                });
1929            }
1930        }
1931        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1932        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1933        let diagnostic_search_range =
1934            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1935
1936        let related_files = self.context_for_project(&project, cx);
1937
1938        let inputs = EditPredictionModelInput {
1939            project: project.clone(),
1940            buffer: active_buffer,
1941            snapshot: snapshot,
1942            position,
1943            events,
1944            related_files,
1945            recent_paths: project_state.recent_paths.clone(),
1946            trigger,
1947            diagnostic_search_range: diagnostic_search_range,
1948            debug_tx,
1949            user_actions,
1950        };
1951
1952        let task = match self.edit_prediction_model {
1953            EditPredictionModel::Zeta1 => zeta::request_prediction_with_zeta(
1954                self,
1955                inputs,
1956                Some(zeta_prompt::EditPredictionModelKind::Zeta1),
1957                cx,
1958            ),
1959            EditPredictionModel::Zeta2 => zeta::request_prediction_with_zeta(
1960                self,
1961                inputs,
1962                Some(zeta_prompt::EditPredictionModelKind::Zeta2),
1963                cx,
1964            ),
1965            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
1966            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1967            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1968        };
1969
1970        cx.spawn(async move |this, cx| {
1971            let prediction = task.await?;
1972
1973            if prediction.is_none() && allow_jump && has_events {
1974                this.update(cx, |this, cx| {
1975                    this.refresh_prediction_from_diagnostics(
1976                        project,
1977                        DiagnosticSearchScope::Local,
1978                        cx,
1979                    );
1980                })?;
1981                return anyhow::Ok(None);
1982            }
1983
1984            Ok(prediction)
1985        })
1986    }
1987
1988    pub(crate) async fn next_diagnostic_location(
1989        active_buffer: Entity<Buffer>,
1990        active_buffer_snapshot: &BufferSnapshot,
1991        active_buffer_diagnostic_search_range: Range<Point>,
1992        active_buffer_cursor_point: Point,
1993        project: &Entity<Project>,
1994        cx: &mut AsyncApp,
1995    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1996        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
1997            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
1998            .flat_map(|(_, _, _, selections)| {
1999                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2000            })
2001            .collect();
2002
2003        let mut jump_location = active_buffer_snapshot
2004            .diagnostic_groups(None)
2005            .into_iter()
2006            .filter_map(|(_, group)| {
2007                let range = &group.entries[group.primary_ix]
2008                    .range
2009                    .to_point(&active_buffer_snapshot);
2010                if range.overlaps(&active_buffer_diagnostic_search_range) {
2011                    return None;
2012                }
2013                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2014                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2015                });
2016                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2017                    <= DIAGNOSTIC_LINES_RANGE;
2018                if near_collaborator && !near_local {
2019                    return None;
2020                }
2021                Some(range.start)
2022            })
2023            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2024            .map(|position| {
2025                (
2026                    active_buffer.clone(),
2027                    active_buffer_snapshot.anchor_before(position),
2028                )
2029            });
2030
2031        if jump_location.is_none() {
2032            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2033                let file = buffer.file()?;
2034
2035                Some(ProjectPath {
2036                    worktree_id: file.worktree_id(cx),
2037                    path: file.path().clone(),
2038                })
2039            });
2040
2041            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2042                project
2043                    .diagnostic_summaries(false, cx)
2044                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2045                    .map(|(path, _, _)| {
2046                        let shared_prefix = path
2047                            .path
2048                            .components()
2049                            .zip(
2050                                active_buffer_path
2051                                    .as_ref()
2052                                    .map(|p| p.path.components())
2053                                    .unwrap_or_default(),
2054                            )
2055                            .take_while(|(a, b)| a == b)
2056                            .count();
2057                        (path, shared_prefix)
2058                    })
2059                    .collect()
2060            });
2061
2062            candidates.sort_by(|a, b| b.1.cmp(&a.1));
2063
2064            for (path, _) in candidates {
2065                let candidate_buffer = project
2066                    .update(cx, |project, cx| project.open_buffer(path, cx))
2067                    .await?;
2068
2069                let (has_collaborators, diagnostic_position) =
2070                    candidate_buffer.read_with(cx, |buffer, _cx| {
2071                        let snapshot = buffer.snapshot();
2072                        let has_collaborators = snapshot
2073                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2074                            .next()
2075                            .is_some();
2076                        let position = buffer
2077                            .buffer_diagnostics(None)
2078                            .into_iter()
2079                            .min_by_key(|entry| entry.diagnostic.severity)
2080                            .map(|entry| entry.range.start);
2081                        (has_collaborators, position)
2082                    });
2083
2084                if has_collaborators {
2085                    continue;
2086                }
2087
2088                if let Some(position) = diagnostic_position {
2089                    jump_location = Some((candidate_buffer, position));
2090                    break;
2091                }
2092            }
2093        }
2094
2095        anyhow::Ok(jump_location)
2096    }
2097
2098    async fn send_raw_llm_request(
2099        request: RawCompletionRequest,
2100        client: Arc<Client>,
2101        custom_url: Option<Arc<Url>>,
2102        llm_token: LlmApiToken,
2103        app_version: Version,
2104    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2105        let url = if let Some(custom_url) = custom_url {
2106            custom_url.as_ref().clone()
2107        } else {
2108            client
2109                .http_client()
2110                .build_zed_llm_url("/predict_edits/raw", &[])?
2111        };
2112
2113        Self::send_api_request(
2114            |builder| {
2115                let req = builder
2116                    .uri(url.as_ref())
2117                    .body(serde_json::to_string(&request)?.into());
2118                Ok(req?)
2119            },
2120            client,
2121            llm_token,
2122            app_version,
2123            true,
2124        )
2125        .await
2126    }
2127
2128    pub(crate) async fn send_v3_request(
2129        input: ZetaPromptInput,
2130        client: Arc<Client>,
2131        llm_token: LlmApiToken,
2132        app_version: Version,
2133        trigger: PredictEditsRequestTrigger,
2134    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2135        let url = client
2136            .http_client()
2137            .build_zed_llm_url("/predict_edits/v3", &[])?;
2138
2139        let request = PredictEditsV3Request { input, trigger };
2140
2141        let json_bytes = serde_json::to_vec(&request)?;
2142        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2143
2144        Self::send_api_request(
2145            |builder| {
2146                let req = builder
2147                    .uri(url.as_ref())
2148                    .header("Content-Encoding", "zstd")
2149                    .body(compressed.clone().into());
2150                Ok(req?)
2151            },
2152            client,
2153            llm_token,
2154            app_version,
2155            true,
2156        )
2157        .await
2158    }
2159
2160    fn handle_api_response<T>(
2161        this: &WeakEntity<Self>,
2162        response: Result<(T, Option<EditPredictionUsage>)>,
2163        cx: &mut gpui::AsyncApp,
2164    ) -> Result<T> {
2165        match response {
2166            Ok((data, usage)) => {
2167                if let Some(usage) = usage {
2168                    this.update(cx, |this, cx| {
2169                        this.user_store.update(cx, |user_store, cx| {
2170                            user_store.update_edit_prediction_usage(usage, cx);
2171                        });
2172                    })
2173                    .ok();
2174                }
2175                Ok(data)
2176            }
2177            Err(err) => {
2178                if err.is::<ZedUpdateRequiredError>() {
2179                    cx.update(|cx| {
2180                        this.update(cx, |this, _cx| {
2181                            this.update_required = true;
2182                        })
2183                        .ok();
2184
2185                        let error_message: SharedString = err.to_string().into();
2186                        show_app_notification(
2187                            NotificationId::unique::<ZedUpdateRequiredError>(),
2188                            cx,
2189                            move |cx| {
2190                                cx.new(|cx| {
2191                                    ErrorMessagePrompt::new(error_message.clone(), cx)
2192                                        .with_link_button("Update Zed", "https://zed.dev/releases")
2193                                })
2194                            },
2195                        );
2196                    });
2197                }
2198                Err(err)
2199            }
2200        }
2201    }
2202
2203    async fn send_api_request<Res>(
2204        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2205        client: Arc<Client>,
2206        llm_token: LlmApiToken,
2207        app_version: Version,
2208        require_auth: bool,
2209    ) -> Result<(Res, Option<EditPredictionUsage>)>
2210    where
2211        Res: DeserializeOwned,
2212    {
2213        let http_client = client.http_client();
2214
2215        let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
2216            Some(custom_token)
2217        } else if require_auth {
2218            Some(llm_token.acquire(&client).await?)
2219        } else {
2220            llm_token.acquire(&client).await.ok()
2221        };
2222        let mut did_retry = false;
2223
2224        loop {
2225            let request_builder = http_client::Request::builder().method(Method::POST);
2226
2227            let mut request_builder = request_builder
2228                .header("Content-Type", "application/json")
2229                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2230
2231            // Only add Authorization header if we have a token
2232            if let Some(ref token_value) = token {
2233                request_builder =
2234                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2235            }
2236
2237            let request = build(request_builder)?;
2238
2239            let mut response = http_client.send(request).await?;
2240
2241            if let Some(minimum_required_version) = response
2242                .headers()
2243                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2244                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2245            {
2246                anyhow::ensure!(
2247                    app_version >= minimum_required_version,
2248                    ZedUpdateRequiredError {
2249                        minimum_version: minimum_required_version
2250                    }
2251                );
2252            }
2253
2254            if response.status().is_success() {
2255                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2256
2257                let mut body = Vec::new();
2258                response.body_mut().read_to_end(&mut body).await?;
2259                return Ok((serde_json::from_slice(&body)?, usage));
2260            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2261                did_retry = true;
2262                token = Some(llm_token.refresh(&client).await?);
2263            } else {
2264                let mut body = String::new();
2265                response.body_mut().read_to_string(&mut body).await?;
2266                anyhow::bail!(
2267                    "Request failed with status: {:?}\nBody: {}",
2268                    response.status(),
2269                    body
2270                );
2271            }
2272        }
2273    }
2274
2275    pub fn refresh_context(
2276        &mut self,
2277        project: &Entity<Project>,
2278        buffer: &Entity<language::Buffer>,
2279        cursor_position: language::Anchor,
2280        cx: &mut Context<Self>,
2281    ) {
2282        self.get_or_init_project(project, cx)
2283            .context
2284            .update(cx, |store, cx| {
2285                store.refresh(buffer.clone(), cursor_position, cx);
2286            });
2287    }
2288
2289    #[cfg(feature = "cli-support")]
2290    pub fn set_context_for_buffer(
2291        &mut self,
2292        project: &Entity<Project>,
2293        related_files: Vec<RelatedFile>,
2294        cx: &mut Context<Self>,
2295    ) {
2296        self.get_or_init_project(project, cx)
2297            .context
2298            .update(cx, |store, cx| {
2299                store.set_related_files(related_files, cx);
2300            });
2301    }
2302
2303    #[cfg(feature = "cli-support")]
2304    pub fn set_recent_paths_for_project(
2305        &mut self,
2306        project: &Entity<Project>,
2307        paths: impl IntoIterator<Item = project::ProjectPath>,
2308        cx: &mut Context<Self>,
2309    ) {
2310        let project_state = self.get_or_init_project(project, cx);
2311        project_state.recent_paths = paths.into_iter().collect();
2312    }
2313
2314    fn is_file_open_source(
2315        &self,
2316        project: &Entity<Project>,
2317        file: &Arc<dyn File>,
2318        cx: &App,
2319    ) -> bool {
2320        if !file.is_local() || file.is_private() {
2321            return false;
2322        }
2323        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2324            return false;
2325        };
2326        project_state
2327            .license_detection_watchers
2328            .get(&file.worktree_id(cx))
2329            .as_ref()
2330            .is_some_and(|watcher| watcher.is_project_open_source())
2331    }
2332
2333    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2334        self.data_collection_choice.is_enabled(cx)
2335    }
2336
2337    fn load_data_collection_choice() -> DataCollectionChoice {
2338        let choice = KEY_VALUE_STORE
2339            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2340            .log_err()
2341            .flatten();
2342
2343        match choice.as_deref() {
2344            Some("true") => DataCollectionChoice::Enabled,
2345            Some("false") => DataCollectionChoice::Disabled,
2346            Some(_) => {
2347                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2348                DataCollectionChoice::NotAnswered
2349            }
2350            None => DataCollectionChoice::NotAnswered,
2351        }
2352    }
2353
2354    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2355        self.data_collection_choice = self.data_collection_choice.toggle();
2356        let new_choice = self.data_collection_choice;
2357        let is_enabled = new_choice.is_enabled(cx);
2358        db::write_and_log(cx, move || {
2359            KEY_VALUE_STORE.write_kvp(
2360                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2361                is_enabled.to_string(),
2362            )
2363        });
2364    }
2365
    /// Returns a double-ended iterator over the predictions stored in
    /// `shown_predictions`.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2369
    /// Returns the number of entries in `shown_predictions`.
    // NOTE(review): the name says "completions" but this counts shown predictions.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2373
    /// Returns whether `id` is in `rated_predictions`, the set that
    /// `rate_prediction` inserts into.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2377
2378    pub fn rate_prediction(
2379        &mut self,
2380        prediction: &EditPrediction,
2381        rating: EditPredictionRating,
2382        feedback: String,
2383        cx: &mut Context<Self>,
2384    ) {
2385        let organization = self.user_store.read(cx).current_organization();
2386
2387        self.rated_predictions.insert(prediction.id.clone());
2388
2389        cx.background_spawn({
2390            let client = self.client.clone();
2391            let prediction_id = prediction.id.to_string();
2392            let inputs = serde_json::to_value(&prediction.inputs);
2393            let output = prediction
2394                .edit_preview
2395                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2396            async move {
2397                client
2398                    .cloud_client()
2399                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2400                        organization_id: organization.map(|organization| organization.id.clone()),
2401                        request_id: prediction_id,
2402                        rating: match rating {
2403                            EditPredictionRating::Positive => "positive".to_string(),
2404                            EditPredictionRating::Negative => "negative".to_string(),
2405                        },
2406                        inputs: inputs?,
2407                        output,
2408                        feedback,
2409                    })
2410                    .await?;
2411
2412                anyhow::Ok(())
2413            }
2414        })
2415        .detach_and_log_err(cx);
2416
2417        cx.notify();
2418    }
2419}
2420
/// Collapses the trailing run of mergeable events at the back of `events` into
/// a single `StoredEvent`.
///
/// The walk starts at the newest event and moves backwards. An older event
/// joins the run while `StoredEvent::can_merge` accepts it against its (newer)
/// successor, evaluated with `latest_snapshot` and `latest_edit_range`. If the
/// run holds at least two events, they are replaced by one `BufferChange`
/// event. Its diff runs from the oldest run member's `old_snapshot` to
/// `end_snapshot`.
///
/// `events` is left untouched in three cases:
/// - The newest event's `old_snapshot` belongs to a different buffer than
///   `latest_snapshot` (compared via `remote_id`).
/// - Fewer than two events are mergeable.
/// - `compute_diff_between_snapshots` returns `None` (presumably meaning there
///   is no net change between the two snapshots; TODO confirm).
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Bail out if the most recent event is for a different buffer than the
    // latest edit.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
    }

    // Count the contiguous mergeable run, walking from newest to oldest.
    // `next_old_event` is the newer event visited on the previous iteration.
    // The newest event always counts, so the run stops at the first older
    // event that cannot merge with its successor.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A lone event has nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    // Iterate the run oldest-to-newest. The first item (peeked, not consumed)
    // is the oldest event, whose `old_snapshot` is the merged diff's base.
    // `unwrap` cannot fail: `mergeable_count >= 2` here.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        // The merged event keeps the oldest event's paths and open-source
        // flag. It is `predicted` only if every event in the run (including
        // the oldest) was predicted.
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                // The edited range is anchored in `end_snapshot`, where the
                // merged diff ends.
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        // Replace the entire run with the single merged event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2487
2488pub(crate) fn filter_redundant_excerpts(
2489    mut related_files: Vec<RelatedFile>,
2490    cursor_path: &Path,
2491    cursor_row_range: Range<u32>,
2492) -> Vec<RelatedFile> {
2493    for file in &mut related_files {
2494        if file.path.as_ref() == cursor_path {
2495            file.excerpts.retain(|excerpt| {
2496                excerpt.row_range.start < cursor_row_range.start
2497                    || excerpt.row_range.end > cursor_row_range.end
2498            });
2499        }
2500    }
2501    related_files.retain(|file| !file.excerpts.is_empty());
2502    related_files
2503}
2504
/// Error reporting that this Zed build is too old to keep using edit
/// predictions.
///
/// The user-facing message comes from the `#[error]` attribute and names
/// `minimum_version`.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest Zed version the user must update to in order to continue using
    /// edit predictions.
    minimum_version: Version,
}
2512
/// The user's answer to whether edit-prediction data may be collected.
///
/// Stored in the key-value store as `"true"` / `"false"`. A missing or
/// unrecognized stored value loads as `NotAnswered`. `PartialEq`/`Eq` are
/// derived so callers can compare choices directly instead of pattern
/// matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    /// No choice has been recorded (or the stored value was not recognized).
    NotAnswered,
    /// The user opted in to data collection.
    Enabled,
    /// The user opted out of data collection.
    Disabled,
}
2519
2520impl DataCollectionChoice {
2521    pub fn is_enabled(self, cx: &App) -> bool {
2522        if cx.is_staff() {
2523            return true;
2524        }
2525        match self {
2526            Self::Enabled => true,
2527            Self::NotAnswered | Self::Disabled => false,
2528        }
2529    }
2530
2531    #[must_use]
2532    pub fn toggle(&self) -> DataCollectionChoice {
2533        match self {
2534            Self::Enabled => Self::Disabled,
2535            Self::Disabled => Self::Enabled,
2536            Self::NotAnswered => Self::Enabled,
2537        }
2538    }
2539}
2540
2541impl From<bool> for DataCollectionChoice {
2542    fn from(value: bool) -> Self {
2543        match value {
2544            true => DataCollectionChoice::Enabled,
2545            false => DataCollectionChoice::Disabled,
2546        }
2547    }
2548}
2549
2550struct ZedPredictUpsell;
2551
2552impl Dismissable for ZedPredictUpsell {
2553    const KEY: &'static str = "dismissed-edit-predict-upsell";
2554
2555    fn dismissed() -> bool {
2556        // To make this backwards compatible with older versions of Zed, we
2557        // check if the user has seen the previous Edit Prediction Onboarding
2558        // before, by checking the data collection choice which was written to
2559        // the database once the user clicked on "Accept and Enable"
2560        if KEY_VALUE_STORE
2561            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2562            .log_err()
2563            .is_some_and(|s| s.is_some())
2564        {
2565            return true;
2566        }
2567
2568        KEY_VALUE_STORE
2569            .read_kvp(Self::KEY)
2570            .log_err()
2571            .is_some_and(|s| s.is_some())
2572    }
2573}
2574
2575pub fn should_show_upsell_modal() -> bool {
2576    !ZedPredictUpsell::dismissed()
2577}
2578
2579pub fn init(cx: &mut App) {
2580    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2581        workspace.register_action(
2582            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2583                ZedPredictModal::toggle(
2584                    workspace,
2585                    workspace.user_store().clone(),
2586                    workspace.client().clone(),
2587                    window,
2588                    cx,
2589                )
2590            },
2591        );
2592
2593        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2594            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2595                settings
2596                    .project
2597                    .all_languages
2598                    .edit_predictions
2599                    .get_or_insert_default()
2600                    .provider = Some(EditPredictionProvider::None)
2601            });
2602        });
2603        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2604            EditPredictionStore::try_global(cx).and_then(|store| {
2605                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2606            })
2607        }
2608
2609        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2610            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2611                copilot_ui::initiate_sign_in(copilot, window, cx);
2612            }
2613        });
2614        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2615            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2616                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2617            }
2618        });
2619        workspace.register_action(|workspace, _: &SignOut, window, cx| {
2620            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2621                copilot_ui::initiate_sign_out(copilot, window, cx);
2622            }
2623        });
2624    })
2625    .detach();
2626}