edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
   5use cloud_llm_client::predict_edits_v3::{
   6    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   7};
   8use cloud_llm_client::{
   9    EditPredictionRejectReason, EditPredictionRejection,
  10    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  11    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  12};
  13use collections::{HashMap, HashSet};
  14use copilot::{Copilot, Reinstall, SignIn, SignOut};
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::{
  19    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  20    channel::mpsc::{self, UnboundedReceiver},
  21    select_biased,
  22};
  23use gpui::BackgroundExecutor;
  24use gpui::http_client::Url;
  25use gpui::{
  26    App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
  27    http_client::{self, AsyncBody, Method},
  28    prelude::*,
  29};
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{
  39    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  40};
  41use std::collections::{VecDeque, hash_map};
  42use std::env;
  43use text::{AnchorRangeExt, Edit};
  44use workspace::Workspace;
  45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  46
  47use std::mem;
  48use std::ops::Range;
  49use std::path::Path;
  50use std::rc::Rc;
  51use std::str::FromStr as _;
  52use std::sync::Arc;
  53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  54use thiserror::Error;
  55use util::{RangeExt as _, ResultExt as _};
  56
  57pub mod cursor_excerpt;
  58pub mod example_spec;
  59pub mod fim;
  60mod license_detection;
  61pub mod mercury;
  62pub mod ollama;
  63mod onboarding_modal;
  64pub mod open_ai_response;
  65mod prediction;
  66pub mod sweep_ai;
  67
  68pub mod udiff;
  69
  70mod capture_example;
  71pub mod open_ai_compatible;
  72mod zed_edit_prediction_delegate;
  73pub mod zeta;
  74
  75#[cfg(test)]
  76mod edit_prediction_tests;
  77
  78use crate::example_spec::ExampleSpec;
  79use crate::license_detection::LicenseDetectionWatcher;
  80use crate::mercury::Mercury;
  81use crate::onboarding_modal::ZedPredictModal;
  82pub use crate::prediction::EditPrediction;
  83pub use crate::prediction::EditPredictionId;
  84use crate::prediction::EditPredictionResult;
  85pub use crate::sweep_ai::SweepAi;
  86pub use capture_example::capture_example;
  87pub use language_model::ApiKeyState;
  88pub use telemetry_events::EditPredictionRating;
  89pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  90
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);

/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Line distance used by `StoredEvent::can_merge`: events within this many lines of
/// the latest edit are never merged across sources, and events farther apart than
/// this from each other are never merged.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): not referenced in this chunk; presumably the time window within which
// consecutive edits are grouped into one event — confirm at the use site.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// NOTE(review): presumably the key under which the data-collection choice is persisted
// (`KEY_VALUE_STORE` is imported) — the loader is not visible here.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the batching/debounce interval for sending rejection
// reports; the consumer (`handle_rejected_predictions`) is not visible here.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
// NOTE(review): presumably the telemetry event name for settled predictions.
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
/// Five minutes. NOTE(review): presumably how long a prediction is tracked while
/// waiting for it to settle — confirm in the settled-predictions worker.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
/// Ten seconds. NOTE(review): presumably the edit-free interval after which a
/// prediction is considered settled — confirm in the settled-predictions worker.
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);

/// Feature flag named `edit_prediction_jumps`.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}

/// App-global slot holding the single `EditPredictionStore`; installed by
/// `EditPredictionStore::global` and read by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}

/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
///
/// NOTE(review): this struct has no separate "version" field — the format string read
/// from `ZED_ZETA_FORMAT` is presumably what "version" refers to, and `environment` can
/// now be set explicitly; confirm whether the sentence above is still accurate.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Model identifier; `zeta2_raw_config_from_env` reads it from `ZED_ZETA_MODEL`.
    pub model_id: Option<String>,
    /// Environment name; `zeta2_raw_config_from_env` reads it from `ZED_ZETA_ENVIRONMENT`.
    pub environment: Option<String>,
    /// Prompt format; `zeta2_raw_config_from_env` parses it from `ZED_ZETA_FORMAT`.
    pub format: ZetaFormat,
}
 131
/// Central edit-prediction state: per-project history, the selected backend,
/// experiment selection, and channels feeding the background workers started in `new`.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token used to authenticate requests to the Zed LLM service
    /// (see `refresh_available_experiments`).
    llm_token: LlmApiToken,
    /// Held so the experiments-fetch task started in `new` is not dropped.
    _fetch_experiments_task: Task<()>,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    // NOTE(review): not read in this chunk; presumably set when the server demands a
    // newer client version (`MINIMUM_REQUIRED_VERSION_HEADER_NAME` is imported) — confirm.
    update_required: bool,
    /// Backend used for predictions; `new` initializes it to `Zeta`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint settings; `new` initializes them from environment variables.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    /// Explicitly chosen experiment; takes precedence in `active_experiment`.
    preferred_experiment: Option<String>,
    /// Experiment names returned by the `/edit_prediction_experiments` endpoint.
    available_experiments: Vec<String>,
    /// Sweep backend; its `api_token` backs `has_sweep_api_token`.
    pub sweep_ai: SweepAi,
    /// Mercury backend; its `api_token` backs `has_mercury_api_token`.
    pub mercury: Mercury,
    /// Loaded once in `new` via `load_data_collection_choice`.
    data_collection_choice: DataCollectionChoice,
    /// Sender side of the channel drained by `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Sender side of the channel drained by `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    /// Predictions that were shown; `active_experiment` scans their model versions.
    shown_predictions: VecDeque<EditPrediction>,
    // NOTE(review): presumably ids of predictions the user has already rated — confirm.
    rated_predictions: HashSet<EditPredictionId>,
    // Test-only hook. NOTE(review): presumably called when a prediction settles — confirm.
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}

/// A rejection report queued on `reject_predictions_tx`, together with the organization
/// it should be attributed to (if any).
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}

/// Which backend produces edit predictions. `Fim` carries the prompt format to use.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta,
    Fim { format: EditPredictionPromptFormat },
    Sweep,
    Mercury,
}

/// Inputs gathered for a single prediction request.
/// NOTE(review): field semantics below are inferred from names; confirm against the
/// code that builds this struct.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Presumably the cursor position the prediction is requested for.
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictionRequestedTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional channel for `DebugEvent`s about this request.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    pub user_actions: Vec<UserActionRecord>,
}
 183
/// Debug notifications sent on a `debug_tx` channel around context retrieval and
/// prediction requests.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Emitted when context retrieval begins for a project.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Emitted when context retrieval completes; `metadata` holds labelled details.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when a prediction request starts; `prompt` is present when one was built.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Emitted when a prediction request finishes; `model_output` is present when the
/// model returned output.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}

/// Maximum number of entries `ProjectState::record_user_action` keeps; the oldest
/// entry is evicted first.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One recorded user action in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    /// Per its name, milliseconds since the Unix epoch — NOTE(review): confirm at producer.
    pub timestamp_epoch_ms: u64,
}

/// Kind of user action captured in a `UserActionRecord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}

/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The event itself (built as a `BufferChange` by `LastEvent::finalize`).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Span of the change, anchored in the post-change snapshot.
    pub edit_range: Range<Anchor>,
}
 247
 248impl StoredEvent {
 249    fn can_merge(
 250        &self,
 251        next_old_event: &&&StoredEvent,
 252        new_snapshot: &TextBufferSnapshot,
 253        last_edit_range: &Range<Anchor>,
 254    ) -> bool {
 255        // Events must be for the same buffer
 256        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 257            return false;
 258        }
 259        if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
 260            return false;
 261        }
 262
 263        let a_is_predicted = matches!(
 264            self.event.as_ref(),
 265            zeta_prompt::Event::BufferChange {
 266                predicted: true,
 267                ..
 268            }
 269        );
 270        let b_is_predicted = matches!(
 271            next_old_event.event.as_ref(),
 272            zeta_prompt::Event::BufferChange {
 273                predicted: true,
 274                ..
 275            }
 276        );
 277
 278        // If events come from the same source (both predicted or both manual) then
 279        // we would have coalesced them already.
 280        if a_is_predicted == b_is_predicted {
 281            return false;
 282        }
 283
 284        let left_range = self.edit_range.to_point(new_snapshot);
 285        let right_range = next_old_event.edit_range.to_point(new_snapshot);
 286        let latest_range = last_edit_range.to_point(&new_snapshot);
 287
 288        // Events near to the latest edit are not merged if their sources differ.
 289        if lines_between_ranges(&left_range, &latest_range)
 290            .min(lines_between_ranges(&right_range, &latest_range))
 291            <= CHANGE_GROUPING_LINE_SPAN
 292        {
 293            return false;
 294        }
 295
 296        // Events that are distant from each other are not merged.
 297        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 298            return false;
 299        }
 300
 301        true
 302    }
 303}
 304
 305fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 306    if left.start > right.end {
 307        return left.start.row - right.end.row;
 308    }
 309    if right.start > left.end {
 310        return right.start.row - left.end.row;
 311    }
 312    0
 313}
 314
/// Per-project edit-prediction state held by `EditPredictionStore::projects`.
struct ProjectState {
    /// Finalized edit events (cleared by `clear_history`).
    events: VecDeque<StoredEvent>,
    /// The in-progress event; `events()` splits it at the last editing pause and
    /// finalizes it on demand.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers registered for tracking, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    // NOTE(review): presumably (buffer id, time) of the last refresh of each kind,
    // used for throttling — confirm at use sites.
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// Per-worktree license watchers; consulted to mark events as open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history maintained by `record_user_action`.
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
 333
 334impl ProjectState {
 335    fn record_user_action(&mut self, action: UserActionRecord) {
 336        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 337            self.user_actions.pop_front();
 338        }
 339        self.user_actions.push_back(action);
 340    }
 341
 342    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 343        self.events
 344            .iter()
 345            .cloned()
 346            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 347                let (one, two) = event.split_by_pause();
 348                let one = one.finalize(&self.license_detection_watchers, cx);
 349                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 350                one.into_iter().chain(two)
 351            }))
 352            .collect()
 353    }
 354
 355    fn cancel_pending_prediction(
 356        &mut self,
 357        pending_prediction: PendingPrediction,
 358        cx: &mut Context<EditPredictionStore>,
 359    ) {
 360        self.cancelled_predictions.insert(pending_prediction.id);
 361
 362        if pending_prediction.drop_on_cancel {
 363            drop(pending_prediction.task);
 364        } else {
 365            cx.spawn(async move |this, cx| {
 366                let Some(prediction_id) = pending_prediction.task.await else {
 367                    return;
 368                };
 369
 370                this.update(cx, |this, cx| {
 371                    this.reject_prediction(
 372                        prediction_id,
 373                        EditPredictionRejectReason::Canceled,
 374                        false,
 375                        None,
 376                        cx,
 377                    );
 378                })
 379                .ok();
 380            })
 381            .detach()
 382        }
 383    }
 384
 385    fn active_buffer(
 386        &self,
 387        project: &Entity<Project>,
 388        cx: &App,
 389    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 390        let project = project.read(cx);
 391        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 392        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 393        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 394        Some((active_buffer, registered_buffer.last_position))
 395    }
 396}
 397
 398#[derive(Debug, Clone)]
 399struct CurrentEditPrediction {
 400    pub requested_by: PredictionRequestedBy,
 401    pub prediction: EditPrediction,
 402    pub was_shown: bool,
 403    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
 404}
 405
 406impl CurrentEditPrediction {
 407    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 408        let Some(new_edits) = self
 409            .prediction
 410            .interpolate(&self.prediction.buffer.read(cx))
 411        else {
 412            return false;
 413        };
 414
 415        if self.prediction.buffer != old_prediction.prediction.buffer {
 416            return true;
 417        }
 418
 419        let Some(old_edits) = old_prediction
 420            .prediction
 421            .interpolate(&old_prediction.prediction.buffer.read(cx))
 422        else {
 423            return true;
 424        };
 425
 426        let requested_by_buffer_id = self.requested_by.buffer_id();
 427
 428        // This reduces the occurrence of UI thrash from replacing edits
 429        //
 430        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 431        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 432            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 433            && old_edits.len() == 1
 434            && new_edits.len() == 1
 435        {
 436            let (old_range, old_text) = &old_edits[0];
 437            let (new_range, new_text) = &new_edits[0];
 438            new_range == old_range && new_text.starts_with(old_text.as_ref())
 439        } else {
 440            true
 441        }
 442    }
 443}
 444
 445#[derive(Debug, Clone)]
 446enum PredictionRequestedBy {
 447    DiagnosticsUpdate,
 448    Buffer(EntityId),
 449}
 450
 451impl PredictionRequestedBy {
 452    pub fn buffer_id(&self) -> Option<EntityId> {
 453        match self {
 454            PredictionRequestedBy::DiagnosticsUpdate => None,
 455            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 456        }
 457    }
 458}
 459
 460const DIAGNOSTIC_LINES_RANGE: u32 = 20;
 461
 462#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 463pub enum DiagnosticSearchScope {
 464    Local,
 465    Global,
 466}
 467
/// An in-flight prediction request tracked in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `cancelled_predictions` when the request is cancelled.
    id: usize,
    /// Resolves to the id of the produced prediction, if any.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 476
 477/// A prediction from the perspective of a buffer.
 478#[derive(Debug)]
 479enum BufferEditPrediction<'a> {
 480    Local { prediction: &'a EditPrediction },
 481    Jump { prediction: &'a EditPrediction },
 482}
 483
 484#[cfg(test)]
 485impl std::ops::Deref for BufferEditPrediction<'_> {
 486    type Target = EditPrediction;
 487
 488    fn deref(&self) -> &Self::Target {
 489        match self {
 490            BufferEditPrediction::Local { prediction } => prediction,
 491            BufferEditPrediction::Jump { prediction } => prediction,
 492        }
 493    }
 494}
 495
/// A shown prediction waiting to "settle".
/// NOTE(review): the settling logic is not visible here; semantics of the timestamps
/// are inferred from names — confirm in the settled-predictions worker.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Buffer region the prediction could edit.
    editable_anchor_range: Range<Anchor>,
    example: Option<ExampleSpec>,
    enqueued_at: Instant,
    last_edit_at: Instant,
}

/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last known position in this buffer; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}

/// The in-progress (not yet finalized) edit event for a buffer.
#[derive(Clone)]
struct LastEvent {
    /// Snapshot before the first edit of this event.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot after the most recent edit of this event.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    // NOTE(review): `finalize` recomputes the range from the snapshots and never reads
    // this field; `split_by_pause` sets it to `None`.
    edit_range: Option<Range<Anchor>>,
    /// Copied into the `predicted` flag of the finalized `BufferChange`.
    predicted: bool,
    /// Boundary at which `split_by_pause` divides this event into two.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 524
 525impl LastEvent {
 526    pub fn finalize(
 527        &self,
 528        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 529        cx: &App,
 530    ) -> Option<StoredEvent> {
 531        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 532        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 533
 534        let in_open_source_repo =
 535            [self.new_file.as_ref(), self.old_file.as_ref()]
 536                .iter()
 537                .all(|file| {
 538                    file.is_some_and(|file| {
 539                        license_detection_watchers
 540                            .get(&file.worktree_id(cx))
 541                            .is_some_and(|watcher| watcher.is_project_open_source())
 542                    })
 543                });
 544
 545        let (diff, edit_range) =
 546            compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 547
 548        if path == old_path && diff.is_empty() {
 549            None
 550        } else {
 551            Some(StoredEvent {
 552                event: Arc::new(zeta_prompt::Event::BufferChange {
 553                    old_path,
 554                    path,
 555                    diff,
 556                    in_open_source_repo,
 557                    predicted: self.predicted,
 558                }),
 559                edit_range: self.new_snapshot.anchor_before(edit_range.start)
 560                    ..self.new_snapshot.anchor_before(edit_range.end),
 561                old_snapshot: self.old_snapshot.clone(),
 562            })
 563        }
 564    }
 565
 566    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 567        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 568            return (self.clone(), None);
 569        };
 570
 571        let before = LastEvent {
 572            old_snapshot: self.old_snapshot.clone(),
 573            new_snapshot: boundary_snapshot.clone(),
 574            old_file: self.old_file.clone(),
 575            new_file: self.new_file.clone(),
 576            edit_range: None,
 577            predicted: self.predicted,
 578            snapshot_after_last_editing_pause: None,
 579            last_edit_time: self.last_edit_time,
 580        };
 581
 582        let after = LastEvent {
 583            old_snapshot: boundary_snapshot.clone(),
 584            new_snapshot: self.new_snapshot.clone(),
 585            old_file: self.old_file.clone(),
 586            new_file: self.new_file.clone(),
 587            edit_range: None,
 588            predicted: self.predicted,
 589            snapshot_after_last_editing_pause: None,
 590            last_edit_time: self.last_edit_time,
 591        };
 592
 593        (before, Some(after))
 594    }
 595}
 596
 597pub(crate) fn compute_diff_between_snapshots(
 598    old_snapshot: &TextBufferSnapshot,
 599    new_snapshot: &TextBufferSnapshot,
 600) -> Option<(String, Range<Point>)> {
 601    let edits: Vec<Edit<usize>> = new_snapshot
 602        .edits_since::<usize>(&old_snapshot.version)
 603        .collect();
 604
 605    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 606
 607    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 608    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 609    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 610    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 611
 612    const CONTEXT_LINES: u32 = 3;
 613
 614    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 615    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 616    let old_context_end_row =
 617        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 618    let new_context_end_row =
 619        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 620
 621    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 622    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 623    let old_end_line_offset = old_snapshot
 624        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 625    let new_end_line_offset = new_snapshot
 626        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 627    let old_edit_range = old_start_line_offset..old_end_line_offset;
 628    let new_edit_range = new_start_line_offset..new_end_line_offset;
 629
 630    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 631    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 632
 633    let diff = language::unified_diff_with_offsets(
 634        &old_region_text,
 635        &new_region_text,
 636        old_context_start_row,
 637        new_context_start_row,
 638    );
 639
 640    Some((diff, new_start_point..new_end_point))
 641}
 642
 643fn buffer_path_with_id_fallback(
 644    file: Option<&Arc<dyn File>>,
 645    snapshot: &TextBufferSnapshot,
 646    cx: &App,
 647) -> Arc<Path> {
 648    if let Some(file) = file {
 649        file.full_path(cx).into()
 650    } else {
 651        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 652    }
 653}
 654
 655impl EditPredictionStore {
 656    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 657        cx.try_global::<EditPredictionStoreGlobal>()
 658            .map(|global| global.0.clone())
 659    }
 660
 661    pub fn global(
 662        client: &Arc<Client>,
 663        user_store: &Entity<UserStore>,
 664        cx: &mut App,
 665    ) -> Entity<Self> {
 666        cx.try_global::<EditPredictionStoreGlobal>()
 667            .map(|global| global.0.clone())
 668            .unwrap_or_else(|| {
 669                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 670                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 671                ep_store
 672            })
 673    }
 674
 675    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 676        let data_collection_choice = Self::load_data_collection_choice();
 677
 678        let llm_token = LlmApiToken::global(cx);
 679
 680        let (reject_tx, reject_rx) = mpsc::unbounded();
 681        cx.background_spawn({
 682            let client = client.clone();
 683            let llm_token = llm_token.clone();
 684            let app_version = AppVersion::global(cx);
 685            let background_executor = cx.background_executor().clone();
 686            async move {
 687                Self::handle_rejected_predictions(
 688                    reject_rx,
 689                    client,
 690                    llm_token,
 691                    app_version,
 692                    background_executor,
 693                )
 694                .await
 695            }
 696        })
 697        .detach();
 698
 699        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
 700        cx.spawn(async move |this, cx| {
 701            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
 702        })
 703        .detach();
 704
 705        let mut current_user = user_store.read(cx).watch_current_user();
 706        let fetch_experiments_task = cx.spawn(async move |this, cx| {
 707            while current_user.borrow().is_none() {
 708                current_user.next().await;
 709            }
 710            this.update(cx, |this, cx| {
 711                this.refresh_available_experiments(cx);
 712            })
 713            .log_err();
 714        });
 715
 716        let this = Self {
 717            projects: HashMap::default(),
 718            client,
 719            user_store,
 720            llm_token,
 721            _fetch_experiments_task: fetch_experiments_task,
 722            update_required: false,
 723            edit_prediction_model: EditPredictionModel::Zeta,
 724            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 725            preferred_experiment: None,
 726            available_experiments: Vec::new(),
 727            sweep_ai: SweepAi::new(cx),
 728            mercury: Mercury::new(cx),
 729
 730            data_collection_choice,
 731            reject_predictions_tx: reject_tx,
 732            settled_predictions_tx,
 733            rated_predictions: Default::default(),
 734            shown_predictions: Default::default(),
 735            #[cfg(test)]
 736            settled_event_callback: None,
 737        };
 738
 739        this
 740    }
 741
 742    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 743        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 744        let format = ZetaFormat::parse(&version_str).ok()?;
 745        let model_id = env::var("ZED_ZETA_MODEL").ok();
 746        let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
 747        Some(Zeta2RawConfig {
 748            model_id,
 749            environment,
 750            format,
 751        })
 752    }
 753
 754    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 755        self.edit_prediction_model = model;
 756    }
 757
 758    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
 759        self.zeta2_raw_config = Some(config);
 760    }
 761
 762    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
 763        self.zeta2_raw_config.as_ref()
 764    }
 765
 766    pub fn preferred_experiment(&self) -> Option<&str> {
 767        self.preferred_experiment.as_deref()
 768    }
 769
 770    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
 771        self.preferred_experiment = experiment;
 772    }
 773
 774    pub fn available_experiments(&self) -> &[String] {
 775        &self.available_experiments
 776    }
 777
 778    pub fn active_experiment(&self) -> Option<&str> {
 779        self.preferred_experiment.as_deref().or_else(|| {
 780            self.shown_predictions
 781                .iter()
 782                .find_map(|p| p.model_version.as_ref())
 783                .and_then(|model_version| model_version.strip_prefix("zeta2:"))
 784        })
 785    }
 786
 787    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
 788        let client = self.client.clone();
 789        let llm_token = self.llm_token.clone();
 790        let app_version = AppVersion::global(cx);
 791        let organization_id = self
 792            .user_store
 793            .read(cx)
 794            .current_organization()
 795            .map(|organization| organization.id.clone());
 796
 797        cx.spawn(async move |this, cx| {
 798            let experiments = cx
 799                .background_spawn(async move {
 800                    let http_client = client.http_client();
 801                    let token = llm_token.acquire(&client, organization_id).await?;
 802                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
 803                    let request = http_client::Request::builder()
 804                        .method(Method::GET)
 805                        .uri(url.as_ref())
 806                        .header("Authorization", format!("Bearer {}", token))
 807                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 808                        .body(Default::default())?;
 809                    let mut response = http_client.send(request).await?;
 810                    if response.status().is_success() {
 811                        let mut body = Vec::new();
 812                        response.body_mut().read_to_end(&mut body).await?;
 813                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
 814                        Ok(experiments)
 815                    } else {
 816                        let mut body = String::new();
 817                        response.body_mut().read_to_string(&mut body).await?;
 818                        anyhow::bail!(
 819                            "Failed to fetch experiments: {:?}\nBody: {}",
 820                            response.status(),
 821                            body
 822                        );
 823                    }
 824                })
 825                .await?;
 826            this.update(cx, |this, cx| {
 827                this.available_experiments = experiments;
 828                cx.notify();
 829            })?;
 830            anyhow::Ok(())
 831        })
 832        .detach_and_log_err(cx);
 833    }
 834
 835    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 836        use ui::IconName;
 837        match self.edit_prediction_model {
 838            EditPredictionModel::Sweep => {
 839                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 840                    .with_disabled(IconName::SweepAiDisabled)
 841                    .with_up(IconName::SweepAiUp)
 842                    .with_down(IconName::SweepAiDown)
 843                    .with_error(IconName::SweepAiError)
 844            }
 845            EditPredictionModel::Mercury => {
 846                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 847            }
 848            EditPredictionModel::Zeta => {
 849                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 850                    .with_disabled(IconName::ZedPredictDisabled)
 851                    .with_up(IconName::ZedPredictUp)
 852                    .with_down(IconName::ZedPredictDown)
 853                    .with_error(IconName::ZedPredictError)
 854            }
 855            EditPredictionModel::Fim { .. } => {
 856                let settings = &all_language_settings(None, cx).edit_predictions;
 857                match settings.provider {
 858                    EditPredictionProvider::Ollama => {
 859                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 860                    }
 861                    _ => {
 862                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 863                    }
 864                }
 865            }
 866        }
 867    }
 868
 869    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 870        self.sweep_ai.api_token.read(cx).has_key()
 871    }
 872
 873    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 874        self.mercury.api_token.read(cx).has_key()
 875    }
 876
 877    pub fn clear_history(&mut self) {
 878        for project_state in self.projects.values_mut() {
 879            project_state.events.clear();
 880            project_state.last_event.take();
 881        }
 882    }
 883
 884    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 885        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 886            project_state.events.clear();
 887            project_state.last_event.take();
 888        }
 889    }
 890
 891    pub fn edit_history_for_project(
 892        &self,
 893        project: &Entity<Project>,
 894        cx: &App,
 895    ) -> Vec<StoredEvent> {
 896        self.projects
 897            .get(&project.entity_id())
 898            .map(|project_state| project_state.events(cx))
 899            .unwrap_or_default()
 900    }
 901
 902    pub fn context_for_project<'a>(
 903        &'a self,
 904        project: &Entity<Project>,
 905        cx: &'a mut App,
 906    ) -> Vec<RelatedFile> {
 907        self.projects
 908            .get(&project.entity_id())
 909            .map(|project_state| {
 910                project_state.context.update(cx, |context, cx| {
 911                    context
 912                        .related_files_with_buffers(cx)
 913                        .map(|(mut related_file, buffer)| {
 914                            related_file.in_open_source_repo = buffer
 915                                .read(cx)
 916                                .file()
 917                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
 918                            related_file
 919                        })
 920                        .collect()
 921                })
 922            })
 923            .unwrap_or_default()
 924    }
 925
 926    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 927        self.projects
 928            .get(&project.entity_id())
 929            .and_then(|project| project.copilot.clone())
 930    }
 931
 932    pub fn start_copilot_for_project(
 933        &mut self,
 934        project: &Entity<Project>,
 935        cx: &mut Context<Self>,
 936    ) -> Option<Entity<Copilot>> {
 937        if DisableAiSettings::get(None, cx).disable_ai {
 938            return None;
 939        }
 940        let state = self.get_or_init_project(project, cx);
 941
 942        if state.copilot.is_some() {
 943            return state.copilot.clone();
 944        }
 945        let _project = project.clone();
 946        let project = project.read(cx);
 947
 948        let node = project.node_runtime().cloned();
 949        if let Some(node) = node {
 950            let next_id = project.languages().next_language_server_id();
 951            let fs = project.fs().clone();
 952
 953            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 954            state.copilot = Some(copilot.clone());
 955            Some(copilot)
 956        } else {
 957            None
 958        }
 959    }
 960
 961    pub fn context_for_project_with_buffers<'a>(
 962        &'a self,
 963        project: &Entity<Project>,
 964        cx: &'a mut App,
 965    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 966        self.projects
 967            .get(&project.entity_id())
 968            .map(|project| {
 969                project.context.update(cx, |context, cx| {
 970                    context.related_files_with_buffers(cx).collect()
 971                })
 972            })
 973            .unwrap_or_default()
 974    }
 975
 976    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 977        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
 978            self.user_store.read(cx).edit_prediction_usage()
 979        } else {
 980            None
 981        }
 982    }
 983
 984    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 985        self.get_or_init_project(project, cx);
 986    }
 987
 988    pub fn register_buffer(
 989        &mut self,
 990        buffer: &Entity<Buffer>,
 991        project: &Entity<Project>,
 992        cx: &mut Context<Self>,
 993    ) {
 994        let project_state = self.get_or_init_project(project, cx);
 995        Self::register_buffer_impl(project_state, buffer, project, cx);
 996    }
 997
    /// Returns the [`ProjectState`] tracked for `project`, creating it on first use.
    ///
    /// Creating the state does three things:
    /// - builds a `RelatedExcerptStore` for the project and subscribes to its events, which are
    ///   forwarded to `handle_excerpt_store_event`;
    /// - subscribes to the project's own events via `handle_project_event`;
    /// - observes the project's release, removing the state from `self.projects` and notifying
    ///   when the project entity goes away.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            // The initializer only runs for projects that are not tracked yet, so the
            // subscriptions below are set up once per project.
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Detached rather than kept in `_subscriptions`.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                // Stored on the state, so these live exactly as long as this `ProjectState`.
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1038
1039    pub fn remove_project(&mut self, project: &Entity<Project>) {
1040        self.projects.remove(&project.entity_id());
1041    }
1042
1043    fn handle_excerpt_store_event(
1044        &mut self,
1045        project_entity_id: EntityId,
1046        event: &RelatedExcerptStoreEvent,
1047    ) {
1048        if let Some(project_state) = self.projects.get(&project_entity_id) {
1049            if let Some(debug_tx) = project_state.debug_tx.clone() {
1050                match event {
1051                    RelatedExcerptStoreEvent::StartedRefresh => {
1052                        debug_tx
1053                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
1054                                ContextRetrievalStartedDebugEvent {
1055                                    project_entity_id: project_entity_id,
1056                                    timestamp: Instant::now(),
1057                                    search_prompt: String::new(),
1058                                },
1059                            ))
1060                            .ok();
1061                    }
1062                    RelatedExcerptStoreEvent::FinishedRefresh {
1063                        cache_hit_count,
1064                        cache_miss_count,
1065                        mean_definition_latency,
1066                        max_definition_latency,
1067                    } => {
1068                        debug_tx
1069                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
1070                                ContextRetrievalFinishedDebugEvent {
1071                                    project_entity_id: project_entity_id,
1072                                    timestamp: Instant::now(),
1073                                    metadata: vec![
1074                                        (
1075                                            "Cache Hits",
1076                                            format!(
1077                                                "{}/{}",
1078                                                cache_hit_count,
1079                                                cache_hit_count + cache_miss_count
1080                                            )
1081                                            .into(),
1082                                        ),
1083                                        (
1084                                            "Max LSP Time",
1085                                            format!("{} ms", max_definition_latency.as_millis())
1086                                                .into(),
1087                                        ),
1088                                        (
1089                                            "Mean LSP Time",
1090                                            format!("{} ms", mean_definition_latency.as_millis())
1091                                                .into(),
1092                                        ),
1093                                    ],
1094                                },
1095                            ))
1096                            .ok();
1097                    }
1098                }
1099            }
1100        }
1101    }
1102
1103    pub fn debug_info(
1104        &mut self,
1105        project: &Entity<Project>,
1106        cx: &mut Context<Self>,
1107    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1108        let project_state = self.get_or_init_project(project, cx);
1109        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1110        project_state.debug_tx = Some(debug_watch_tx);
1111        debug_watch_rx
1112    }
1113
1114    fn handle_project_event(
1115        &mut self,
1116        project: Entity<Project>,
1117        event: &project::Event,
1118        cx: &mut Context<Self>,
1119    ) {
1120        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1121            return;
1122        }
1123        // TODO [zeta2] init with recent paths
1124        match event {
1125            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1126                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1127                    return;
1128                };
1129                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1130                if let Some(path) = path {
1131                    if let Some(ix) = project_state
1132                        .recent_paths
1133                        .iter()
1134                        .position(|probe| probe == &path)
1135                    {
1136                        project_state.recent_paths.remove(ix);
1137                    }
1138                    project_state.recent_paths.push_front(path);
1139                }
1140            }
1141            project::Event::DiagnosticsUpdated { .. } => {
1142                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1143                    self.refresh_prediction_from_diagnostics(
1144                        project,
1145                        DiagnosticSearchScope::Global,
1146                        cx,
1147                    );
1148                }
1149            }
1150            _ => (),
1151        }
1152    }
1153
    /// Ensures `buffer` is registered in `project_state` and returns its [`RegisteredBuffer`].
    ///
    /// If the buffer has a file in one of the project's worktrees, a license detection watcher
    /// is created for that worktree, once per worktree. The watcher entry is removed again when
    /// the worktree is released.
    ///
    /// A newly registered buffer starts from its current snapshot and file and carries two
    /// subscriptions:
    /// - one that calls `report_changes_for_buffer` (with `is_predicted = false`) on every
    ///   `Edited` event;
    /// - one that removes the registration when the buffer is released.
    ///
    /// An already-registered buffer is returned unchanged.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop this worktree's watcher once the worktree itself is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Hold the project weakly so this subscription does not keep a
                            // strong reference to it; edits are ignored once it is gone.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { .. } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1220
    /// Records the edits made to `buffer` since it was last observed and folds them into the
    /// project's edit history.
    ///
    /// The buffer is registered first if needed. If its version has not changed, this returns
    /// without doing anything. Otherwise it:
    /// - stores the buffer's current snapshot and file on its registration;
    /// - refreshes `last_edit_at` on pending settled predictions whose editable range overlaps
    ///   the edited span;
    /// - records a `UserActionRecord` classifying the change;
    /// - coalesces the change into `last_event` when it continues the same buffer's version
    ///   history, has the same `is_predicted` source, and lies within
    ///   `CHANGE_GROUPING_LINE_SPAN` lines of the previous edit. Otherwise the previous
    ///   `last_event` is finalized into `events` and a new `last_event` is started.
    ///
    /// `is_predicted` is `true` when the change comes from accepting a prediction
    /// (`accept_current_prediction`) and `false` for ordinary buffer edits.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        // Accumulate edit statistics. `edit_range` spans from the first edit's start to the
        // last edit's end, so it also covers any unedited text between separate edits.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Edits touching a pending prediction's editable region restart its quiescence clock.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // One unit inserted/deleted per edit is treated as a single-character action; anything
        // larger is a selection-sized action.
        // NOTE(review): the lengths come from `usize` offsets, presumably bytes, so typing one
        // multi-byte character (e.g. "é") classifies as `InsertSelection`. Confirm this is
        // intended.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // Wall-clock timestamp; falls back to 0 if the system clock is before the epoch.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        if let Some(last_event) = project_state.last_event.as_mut() {
            // This change directly follows the last event when the buffer is the same and the
            // last event ended at exactly the version this change started from.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // After a pause of at least `LAST_CHANGE_GROUPING_TIME`, remember the snapshot
                // as it was before this edit.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // After this push `events` holds at most `EVENT_COUNT_MAX - 1` entries
                // (presumably leaving room for `last_event`; confirm).
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        // `merge_trailing_events_if_needed` is defined outside this view.
        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1354
1355    fn prediction_at(
1356        &mut self,
1357        buffer: &Entity<Buffer>,
1358        position: Option<language::Anchor>,
1359        project: &Entity<Project>,
1360        cx: &App,
1361    ) -> Option<BufferEditPrediction<'_>> {
1362        let project_state = self.projects.get_mut(&project.entity_id())?;
1363        if let Some(position) = position
1364            && let Some(buffer) = project_state
1365                .registered_buffers
1366                .get_mut(&buffer.entity_id())
1367        {
1368            buffer.last_position = Some(position);
1369        }
1370
1371        let CurrentEditPrediction {
1372            requested_by,
1373            prediction,
1374            ..
1375        } = project_state.current_prediction.as_ref()?;
1376
1377        if prediction.targets_buffer(buffer.read(cx)) {
1378            Some(BufferEditPrediction::Local { prediction })
1379        } else {
1380            let show_jump = match requested_by {
1381                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1382                    requested_by_buffer_id == &buffer.entity_id()
1383                }
1384                PredictionRequestedBy::DiagnosticsUpdate => true,
1385            };
1386
1387            if show_jump {
1388                Some(BufferEditPrediction::Jump { prediction })
1389            } else {
1390                None
1391            }
1392        }
1393    }
1394
1395    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1396        let Some(current_prediction) = self
1397            .projects
1398            .get_mut(&project.entity_id())
1399            .and_then(|project_state| project_state.current_prediction.take())
1400        else {
1401            return;
1402        };
1403
1404        self.report_changes_for_buffer(&current_prediction.prediction.buffer, project, true, cx);
1405
1406        // can't hold &mut project_state ref across report_changes_for_buffer_call
1407        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1408            return;
1409        };
1410
1411        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1412            project_state.cancel_pending_prediction(pending_prediction, cx);
1413        }
1414
1415        match self.edit_prediction_model {
1416            EditPredictionModel::Sweep => {
1417                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1418            }
1419            EditPredictionModel::Mercury => {
1420                mercury::edit_prediction_accepted(
1421                    current_prediction.prediction.id,
1422                    self.client.http_client(),
1423                    cx,
1424                );
1425            }
1426            EditPredictionModel::Zeta => {
1427                let is_cloud = !matches!(
1428                    all_language_settings(None, cx).edit_predictions.provider,
1429                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1430                );
1431                if is_cloud {
1432                    zeta::edit_prediction_accepted(self, current_prediction, cx)
1433                }
1434            }
1435            EditPredictionModel::Fim { .. } => {}
1436        }
1437    }
1438
    /// Background loop that batches rejected predictions and posts them to
    /// `/predict_edits/reject`.
    ///
    /// Each received rejection is appended to `batched`. While the batch holds fewer than half
    /// of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` entries, the loop waits for either of two
    /// things:
    /// - another queued item, in which case it goes back to collect it;
    /// - the `REJECT_REQUEST_DEBOUNCE` timer, or the channel closing.
    ///
    /// It then sends up to `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the most recent
    /// entries. Entries are removed only after a successful request, so failed ones are sent
    /// again on a later flush. The loop ends when the channel closes.
    ///
    /// NOTE(review): two things to confirm are acceptable:
    /// - the request uses the `organization_id` of the most recently received payload for the
    ///   whole batch;
    /// - `batched` is never trimmed while requests keep failing.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        // Peekable so we can check for more queued items without consuming them.
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            // Small batches wait (debounced) for more items; large batches flush immediately.
            // `peek` yielding `None` means the channel closed, which also flushes.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            // Panics if the LLM URL cannot be built.
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Send the newest `flush_count` entries; older ones stay for a later flush.
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Errors are logged; only successfully sent entries are dropped.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1502
    /// Background loop that reports pending predictions as settled.
    ///
    /// When nothing is scheduled, it waits for an enqueue timestamp on `rx`, drains any other
    /// timestamps already queued, and schedules a wake-up `EDIT_PREDICTION_SETTLED_QUIESCENCE`
    /// after that timestamp.
    ///
    /// On each wake-up it walks every registered buffer's pending predictions:
    /// - entries older than `EDIT_PREDICTION_SETTLED_TTL` are dropped without reporting;
    /// - entries not edited for at least the quiescence period emit
    ///   `EDIT_PREDICTION_SETTLED_EVENT` with the current text of their editable range, and are
    ///   then dropped;
    /// - all other entries stay queued.
    ///
    /// The next wake-up is scheduled from the oldest remaining `last_edit_at`. The loop exits
    /// when `rx` closes or the store entity has been released.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                // `Instant::duration_since` saturates to zero if `wake_time` is already past.
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                // Discard any other queued notifications; the scan below covers all entries.
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();

            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                // Expired: drop without reporting.
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                // Quiet long enough: report the settled region and drop.
                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                        example = pending_prediction.example.take(),
                                    );

                                    return false;
                                }

                                // Still being edited: keep it and track the earliest edit time.
                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            // `None` (nothing left pending) makes the next iteration wait on `rx` again.
            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1587
1588    pub(crate) fn enqueue_settled_prediction(
1589        &mut self,
1590        request_id: EditPredictionId,
1591        project: &Entity<Project>,
1592        edited_buffer: &Entity<Buffer>,
1593        edited_buffer_snapshot: &BufferSnapshot,
1594        editable_offset_range: Range<usize>,
1595        example: Option<ExampleSpec>,
1596        cx: &mut Context<Self>,
1597    ) {
1598        let this = &mut *self;
1599        let project_state = this.get_or_init_project(project, cx);
1600        if let Some(buffer) = project_state
1601            .registered_buffers
1602            .get_mut(&edited_buffer.entity_id())
1603        {
1604            let now = cx.background_executor().now();
1605            buffer.pending_predictions.push(PendingSettledPrediction {
1606                request_id: request_id,
1607                editable_anchor_range: edited_buffer_snapshot
1608                    .anchor_range_around(editable_offset_range),
1609                example,
1610                enqueued_at: now,
1611                last_edit_at: now,
1612            });
1613            this.settled_predictions_tx.unbounded_send(now).ok();
1614        }
1615    }
1616
1617    fn reject_current_prediction(
1618        &mut self,
1619        reason: EditPredictionRejectReason,
1620        project: &Entity<Project>,
1621        cx: &App,
1622    ) {
1623        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1624            project_state.pending_predictions.clear();
1625            if let Some(prediction) = project_state.current_prediction.take() {
1626                let model_version = prediction.prediction.model_version.clone();
1627                self.reject_prediction(
1628                    prediction.prediction.id,
1629                    reason,
1630                    prediction.was_shown,
1631                    model_version,
1632                    cx,
1633                );
1634            }
1635        };
1636    }
1637
1638    fn did_show_current_prediction(
1639        &mut self,
1640        project: &Entity<Project>,
1641        display_type: edit_prediction_types::SuggestionDisplayType,
1642        cx: &mut Context<Self>,
1643    ) {
1644        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1645            return;
1646        };
1647
1648        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1649            return;
1650        };
1651
1652        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1653        let previous_shown_with = current_prediction.shown_with;
1654
1655        if previous_shown_with.is_none() || !is_jump {
1656            current_prediction.shown_with = Some(display_type);
1657        }
1658
1659        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1660
1661        if is_first_non_jump_show {
1662            current_prediction.was_shown = true;
1663        }
1664
1665        let display_type_changed = previous_shown_with != Some(display_type);
1666
1667        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1668            sweep_ai::edit_prediction_shown(
1669                &self.sweep_ai,
1670                self.client.clone(),
1671                &current_prediction.prediction,
1672                display_type,
1673                cx,
1674            );
1675        }
1676
1677        if is_first_non_jump_show {
1678            self.shown_predictions
1679                .push_front(current_prediction.prediction.clone());
1680            if self.shown_predictions.len() > 50 {
1681                let completion = self.shown_predictions.pop_back().unwrap();
1682                self.rated_predictions.remove(&completion.id);
1683            }
1684        }
1685    }
1686
1687    fn reject_prediction(
1688        &mut self,
1689        prediction_id: EditPredictionId,
1690        reason: EditPredictionRejectReason,
1691        was_shown: bool,
1692        model_version: Option<String>,
1693        cx: &App,
1694    ) {
1695        match self.edit_prediction_model {
1696            EditPredictionModel::Zeta => {
1697                let is_cloud = !matches!(
1698                    all_language_settings(None, cx).edit_predictions.provider,
1699                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1700                );
1701
1702                if is_cloud {
1703                    let organization_id = self
1704                        .user_store
1705                        .read(cx)
1706                        .current_organization()
1707                        .map(|organization| organization.id.clone());
1708
1709                    self.reject_predictions_tx
1710                        .unbounded_send(EditPredictionRejectionPayload {
1711                            rejection: EditPredictionRejection {
1712                                request_id: prediction_id.to_string(),
1713                                reason,
1714                                was_shown,
1715                                model_version,
1716                            },
1717                            organization_id,
1718                        })
1719                        .log_err();
1720                }
1721            }
1722            EditPredictionModel::Mercury => {
1723                mercury::edit_prediction_rejected(
1724                    prediction_id,
1725                    was_shown,
1726                    reason,
1727                    self.client.http_client(),
1728                    cx,
1729                );
1730            }
1731            EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1732        }
1733    }
1734
1735    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1736        self.projects
1737            .get(&project.entity_id())
1738            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1739    }
1740
1741    pub fn refresh_prediction_from_buffer(
1742        &mut self,
1743        project: Entity<Project>,
1744        buffer: Entity<Buffer>,
1745        position: language::Anchor,
1746        cx: &mut Context<Self>,
1747    ) {
1748        self.queue_prediction_refresh(
1749            project.clone(),
1750            PredictEditsRequestTrigger::Other,
1751            buffer.entity_id(),
1752            cx,
1753            move |this, cx| {
1754                let Some(request_task) = this
1755                    .update(cx, |this, cx| {
1756                        this.request_prediction(
1757                            &project,
1758                            &buffer,
1759                            position,
1760                            PredictEditsRequestTrigger::Other,
1761                            cx,
1762                        )
1763                    })
1764                    .log_err()
1765                else {
1766                    return Task::ready(anyhow::Ok(None));
1767                };
1768
1769                cx.spawn(async move |_cx| {
1770                    request_task.await.map(|prediction_result| {
1771                        prediction_result.map(|prediction_result| {
1772                            (
1773                                prediction_result,
1774                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1775                            )
1776                        })
1777                    })
1778                })
1779            },
1780        )
1781    }
1782
    /// Queues a prediction request anchored at a diagnostic rather than at the
    /// cursor.
    ///
    /// Does nothing when the configured provider is not served by this store,
    /// when the project has no state yet, or when a current prediction already
    /// exists. Otherwise, using the project's active buffer and cursor, it asks
    /// `next_diagnostic_location` for a target and requests a prediction there.
    /// The search range passed along is `DIAGNOSTIC_LINES_RANGE` rows around the
    /// cursor for `DiagnosticSearchScope::Local`, and an empty default range for
    /// `DiagnosticSearchScope::Global`.
    ///
    /// If a current prediction appeared while the request was in flight, the
    /// result is turned into a `CurrentPreferred` rejection.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // A prediction requested from the buffer takes precedence over a diagnostic jump.
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            // Diagnostic refreshes are throttled per project, not per buffer.
            project.entity_id(),
            cx,
            move |this, cx| {
                // Capture the active buffer, its snapshot and the cursor row at the
                // time the (throttled) refresh actually runs.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        // Without a known cursor position, search from the buffer start.
                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // Re-check: a buffer-driven prediction may have landed while this
                    // request was in flight; if so, keep it and reject this one.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
1894
1895    fn predictions_enabled_at(
1896        snapshot: &BufferSnapshot,
1897        position: Option<language::Anchor>,
1898        cx: &App,
1899    ) -> bool {
1900        let file = snapshot.file();
1901        let all_settings = all_language_settings(file, cx);
1902        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1903            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1904        {
1905            return false;
1906        }
1907
1908        if let Some(last_position) = position {
1909            let settings = snapshot.settings_at(last_position, cx);
1910
1911            if !settings.edit_predictions_disabled_in.is_empty()
1912                && let Some(scope) = snapshot.language_scope_at(last_position)
1913                && let Some(scope_name) = scope.override_name()
1914                && settings
1915                    .edit_predictions_disabled_in
1916                    .iter()
1917                    .any(|s| s == scope_name)
1918            {
1919                return false;
1920            }
1921        }
1922
1923        true
1924    }
1925
    /// Minimum spacing between two prediction requests that share a throttle
    /// slot and throttle entity (see `queue_prediction_refresh`).
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1927}
1928
1929fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1930    match provider {
1931        EditPredictionProvider::Zed
1932        | EditPredictionProvider::Sweep
1933        | EditPredictionProvider::Mercury
1934        | EditPredictionProvider::Ollama
1935        | EditPredictionProvider::OpenAiCompatibleApi
1936        | EditPredictionProvider::Experimental(_) => true,
1937        EditPredictionProvider::None
1938        | EditPredictionProvider::Copilot
1939        | EditPredictionProvider::Codestral => false,
1940    }
1941}
1942
1943impl EditPredictionStore {
1944    fn queue_prediction_refresh(
1945        &mut self,
1946        project: Entity<Project>,
1947        request_trigger: PredictEditsRequestTrigger,
1948        throttle_entity: EntityId,
1949        cx: &mut Context<Self>,
1950        do_refresh: impl FnOnce(
1951            WeakEntity<Self>,
1952            &mut AsyncApp,
1953        )
1954            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1955        + 'static,
1956    ) {
1957        fn select_throttle(
1958            project_state: &mut ProjectState,
1959            request_trigger: PredictEditsRequestTrigger,
1960        ) -> &mut Option<(EntityId, Instant)> {
1961            match request_trigger {
1962                PredictEditsRequestTrigger::Diagnostics => {
1963                    &mut project_state.last_jump_prediction_refresh
1964                }
1965                _ => &mut project_state.last_edit_prediction_refresh,
1966            }
1967        }
1968
1969        let (needs_acceptance_tracking, max_pending_predictions) =
1970            match all_language_settings(None, cx).edit_predictions.provider {
1971                EditPredictionProvider::Zed
1972                | EditPredictionProvider::Sweep
1973                | EditPredictionProvider::Mercury
1974                | EditPredictionProvider::Experimental(_) => (true, 2),
1975                EditPredictionProvider::Ollama => (false, 1),
1976                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
1977                EditPredictionProvider::None
1978                | EditPredictionProvider::Copilot
1979                | EditPredictionProvider::Codestral => {
1980                    log::error!("queue_prediction_refresh called with non-store provider");
1981                    return;
1982                }
1983            };
1984
1985        let drop_on_cancel = !needs_acceptance_tracking;
1986        let throttle_timeout = Self::THROTTLE_TIMEOUT;
1987        let project_state = self.get_or_init_project(&project, cx);
1988        let pending_prediction_id = project_state.next_pending_prediction_id;
1989        project_state.next_pending_prediction_id += 1;
1990        let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
1991
1992        let task = cx.spawn(async move |this, cx| {
1993            let throttle_wait = this
1994                .update(cx, |this, cx| {
1995                    let project_state = this.get_or_init_project(&project, cx);
1996                    let throttle = *select_throttle(project_state, request_trigger);
1997
1998                    throttle.and_then(|(last_entity, last_timestamp)| {
1999                        if throttle_entity != last_entity {
2000                            return None;
2001                        }
2002                        (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
2003                    })
2004                })
2005                .ok()
2006                .flatten();
2007
2008            if let Some(timeout) = throttle_wait {
2009                cx.background_executor().timer(timeout).await;
2010            }
2011
2012            // If this task was cancelled before the throttle timeout expired,
2013            // do not perform a request. Also skip if another task already
2014            // proceeded since we were enqueued (duplicate).
2015            let mut is_cancelled = true;
2016            this.update(cx, |this, cx| {
2017                let project_state = this.get_or_init_project(&project, cx);
2018                let was_cancelled = project_state
2019                    .cancelled_predictions
2020                    .remove(&pending_prediction_id);
2021                if was_cancelled {
2022                    return;
2023                }
2024
2025                // Another request has been already sent since this was enqueued
2026                if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2027                    return;
2028                }
2029
2030                let new_refresh = (throttle_entity, Instant::now());
2031                *select_throttle(project_state, request_trigger) = Some(new_refresh);
2032                is_cancelled = false;
2033            })
2034            .ok();
2035            if is_cancelled {
2036                return None;
2037            }
2038
2039            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2040            let new_prediction_id = new_prediction_result
2041                .as_ref()
2042                .map(|(prediction, _)| prediction.id.clone());
2043
2044            // When a prediction completes, remove it from the pending list, and cancel
2045            // any pending predictions that were enqueued before it.
2046            this.update(cx, |this, cx| {
2047                let project_state = this.get_or_init_project(&project, cx);
2048
2049                let is_cancelled = project_state
2050                    .cancelled_predictions
2051                    .remove(&pending_prediction_id);
2052
2053                let new_current_prediction = if !is_cancelled
2054                    && let Some((prediction_result, requested_by)) = new_prediction_result
2055                {
2056                    match prediction_result.prediction {
2057                        Ok(prediction) => {
2058                            let new_prediction = CurrentEditPrediction {
2059                                requested_by,
2060                                prediction,
2061                                was_shown: false,
2062                                shown_with: None,
2063                            };
2064
2065                            if let Some(current_prediction) =
2066                                project_state.current_prediction.as_ref()
2067                            {
2068                                if new_prediction.should_replace_prediction(&current_prediction, cx)
2069                                {
2070                                    this.reject_current_prediction(
2071                                        EditPredictionRejectReason::Replaced,
2072                                        &project,
2073                                        cx,
2074                                    );
2075
2076                                    Some(new_prediction)
2077                                } else {
2078                                    this.reject_prediction(
2079                                        new_prediction.prediction.id,
2080                                        EditPredictionRejectReason::CurrentPreferred,
2081                                        false,
2082                                        new_prediction.prediction.model_version,
2083                                        cx,
2084                                    );
2085                                    None
2086                                }
2087                            } else {
2088                                Some(new_prediction)
2089                            }
2090                        }
2091                        Err(reject_reason) => {
2092                            this.reject_prediction(
2093                                prediction_result.id,
2094                                reject_reason,
2095                                false,
2096                                None,
2097                                cx,
2098                            );
2099                            None
2100                        }
2101                    }
2102                } else {
2103                    None
2104                };
2105
2106                let project_state = this.get_or_init_project(&project, cx);
2107
2108                if let Some(new_prediction) = new_current_prediction {
2109                    project_state.current_prediction = Some(new_prediction);
2110                }
2111
2112                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2113                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2114                    if pending_prediction.id == pending_prediction_id {
2115                        pending_predictions.remove(ix);
2116                        for pending_prediction in pending_predictions.drain(0..ix) {
2117                            project_state.cancel_pending_prediction(pending_prediction, cx)
2118                        }
2119                        break;
2120                    }
2121                }
2122                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2123                cx.notify();
2124            })
2125            .ok();
2126
2127            new_prediction_id
2128        });
2129
2130        if project_state.pending_predictions.len() < max_pending_predictions {
2131            project_state.pending_predictions.push(PendingPrediction {
2132                id: pending_prediction_id,
2133                task,
2134                drop_on_cancel,
2135            });
2136        } else {
2137            let pending_prediction = project_state.pending_predictions.pop().unwrap();
2138            project_state.pending_predictions.push(PendingPrediction {
2139                id: pending_prediction_id,
2140                task,
2141                drop_on_cancel,
2142            });
2143            project_state.cancel_pending_prediction(pending_prediction, cx);
2144        }
2145    }
2146
2147    pub fn request_prediction(
2148        &mut self,
2149        project: &Entity<Project>,
2150        active_buffer: &Entity<Buffer>,
2151        position: language::Anchor,
2152        trigger: PredictEditsRequestTrigger,
2153        cx: &mut Context<Self>,
2154    ) -> Task<Result<Option<EditPredictionResult>>> {
2155        self.request_prediction_internal(
2156            project.clone(),
2157            active_buffer.clone(),
2158            position,
2159            trigger,
2160            cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2161            cx,
2162        )
2163    }
2164
2165    fn request_prediction_internal(
2166        &mut self,
2167        project: Entity<Project>,
2168        active_buffer: Entity<Buffer>,
2169        position: language::Anchor,
2170        trigger: PredictEditsRequestTrigger,
2171        allow_jump: bool,
2172        cx: &mut Context<Self>,
2173    ) -> Task<Result<Option<EditPredictionResult>>> {
2174        self.get_or_init_project(&project, cx);
2175        let project_state = self.projects.get(&project.entity_id()).unwrap();
2176        let stored_events = project_state.events(cx);
2177        let has_events = !stored_events.is_empty();
2178        let events: Vec<Arc<zeta_prompt::Event>> =
2179            stored_events.iter().map(|e| e.event.clone()).collect();
2180        let debug_tx = project_state.debug_tx.clone();
2181
2182        let snapshot = active_buffer.read(cx).snapshot();
2183        let cursor_point = position.to_point(&snapshot);
2184        let current_offset = position.to_offset(&snapshot);
2185
2186        let mut user_actions: Vec<UserActionRecord> =
2187            project_state.user_actions.iter().cloned().collect();
2188
2189        if let Some(last_action) = user_actions.last() {
2190            if last_action.buffer_id == active_buffer.entity_id()
2191                && current_offset != last_action.offset
2192            {
2193                let timestamp_epoch_ms = SystemTime::now()
2194                    .duration_since(UNIX_EPOCH)
2195                    .map(|d| d.as_millis() as u64)
2196                    .unwrap_or(0);
2197                user_actions.push(UserActionRecord {
2198                    action_type: UserActionType::CursorMovement,
2199                    buffer_id: active_buffer.entity_id(),
2200                    line_number: cursor_point.row,
2201                    offset: current_offset,
2202                    timestamp_epoch_ms,
2203                });
2204            }
2205        }
2206        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2207        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2208        let diagnostic_search_range =
2209            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2210
2211        let related_files = self.context_for_project(&project, cx);
2212
2213        let is_open_source = snapshot
2214            .file()
2215            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2216            && events.iter().all(|event| event.in_open_source_repo())
2217            && related_files.iter().all(|file| file.in_open_source_repo);
2218
2219        let can_collect_data = !cfg!(test)
2220            && is_open_source
2221            && self.is_data_collection_enabled(cx)
2222            && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2223
2224        let recent_paths = project_state.recent_paths.clone();
2225
2226        let inputs = EditPredictionModelInput {
2227            project: project.clone(),
2228            buffer: active_buffer,
2229            snapshot,
2230            position,
2231            events,
2232            related_files,
2233            recent_paths,
2234            trigger,
2235            diagnostic_search_range: diagnostic_search_range,
2236            debug_tx,
2237            user_actions,
2238            can_collect_data,
2239            is_open_source,
2240        };
2241
2242        let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2243
2244        let task = match self.edit_prediction_model {
2245            EditPredictionModel::Zeta => {
2246                zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2247            }
2248            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2249            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2250            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2251        };
2252
2253        cx.spawn(async move |this, cx| {
2254            let prediction = task.await?;
2255
2256            // Only fall back to diagnostics-based prediction if we got a
2257            // the model had nothing to suggest for the buffer
2258            if prediction.is_none()
2259                && allow_jump
2260                && has_events
2261                && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2262            {
2263                this.update(cx, |this, cx| {
2264                    this.refresh_prediction_from_diagnostics(
2265                        project,
2266                        DiagnosticSearchScope::Local,
2267                        cx,
2268                    );
2269                })?;
2270                return anyhow::Ok(None);
2271            }
2272
2273            Ok(prediction)
2274        })
2275    }
2276
    /// Picks a buffer and anchor at which to request a diagnostics-driven
    /// prediction.
    ///
    /// First, `active_buffer`'s diagnostic groups are searched using each
    /// group's primary entry. An entry is skipped if it overlaps
    /// `active_buffer_diagnostic_search_range`. It is also skipped if it starts
    /// within `DIAGNOSTIC_LINES_RANGE` rows of a collaborator's selection head
    /// without being that close to the local cursor. Of the remaining entries,
    /// the one whose start row is nearest `active_buffer_cursor_point` wins.
    ///
    /// If nothing qualifies, other project paths with diagnostic summaries are
    /// tried, ordered by how many leading path components they share with the
    /// active buffer (most first). Each candidate buffer is opened. Buffers that
    /// report any selections via `selections_in_range` are treated as having
    /// collaborators and skipped. The first remaining buffer with a diagnostic
    /// yields the start of the diagnostic with the smallest `severity` value.
    ///
    /// Returns `Ok(None)` when no location is found; errors from opening a
    /// candidate buffer are propagated.
    pub(crate) async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Rows of every selection head reported by the snapshot, used as
        // collaborator cursor positions.
        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
            .flat_map(|(_, _, _, selections)| {
                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
            })
            .collect();

        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Diagnostics inside the search range are excluded from jump targets.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    return None;
                }
                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
                });
                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
                    <= DIAGNOSTIC_LINES_RANGE;
                // Leave diagnostics near a collaborator alone unless we are close too.
                if near_collaborator && !near_local {
                    return None;
                }
                Some(range.start)
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Nothing suitable in the active buffer: look at other files with diagnostics.
        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // Pair each other path with the number of leading components it
            // shares with the active buffer's path.
            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
                project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .map(|(path, _, _)| {
                        let shared_prefix = path
                            .path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count();
                        (path, shared_prefix)
                    })
                    .collect()
            });

            // Most shared components first (stable, so ties keep summary order).
            candidates.sort_by(|a, b| b.1.cmp(&a.1));

            for (path, _) in candidates {
                let candidate_buffer = project
                    .update(cx, |project, cx| project.open_buffer(path, cx))
                    .await?;

                let (has_collaborators, diagnostic_position) =
                    candidate_buffer.read_with(cx, |buffer, _cx| {
                        let snapshot = buffer.snapshot();
                        let has_collaborators = snapshot
                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
                            .next()
                            .is_some();
                        // NOTE(review): assumes a smaller severity value means a more
                        // severe diagnostic (as in LSP) — confirm.
                        let position = buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start);
                        (has_collaborators, position)
                    });

                if has_collaborators {
                    continue;
                }

                if let Some(position) = diagnostic_position {
                    jump_location = Some((candidate_buffer, position));
                    break;
                }
            }
        }

        anyhow::Ok(jump_location)
    }
2386
2387    async fn send_raw_llm_request(
2388        request: RawCompletionRequest,
2389        client: Arc<Client>,
2390        custom_url: Option<Arc<Url>>,
2391        llm_token: LlmApiToken,
2392        organization_id: Option<OrganizationId>,
2393        app_version: Version,
2394    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2395        let url = if let Some(custom_url) = custom_url {
2396            custom_url.as_ref().clone()
2397        } else {
2398            client
2399                .http_client()
2400                .build_zed_llm_url("/predict_edits/raw", &[])?
2401        };
2402
2403        Self::send_api_request(
2404            |builder| {
2405                let req = builder
2406                    .uri(url.as_ref())
2407                    .body(serde_json::to_string(&request)?.into());
2408                Ok(req?)
2409            },
2410            client,
2411            llm_token,
2412            organization_id,
2413            app_version,
2414            true,
2415        )
2416        .await
2417    }
2418
2419    pub(crate) async fn send_v3_request(
2420        input: ZetaPromptInput,
2421        client: Arc<Client>,
2422        llm_token: LlmApiToken,
2423        organization_id: Option<OrganizationId>,
2424        app_version: Version,
2425        trigger: PredictEditsRequestTrigger,
2426    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2427        let url = client
2428            .http_client()
2429            .build_zed_llm_url("/predict_edits/v3", &[])?;
2430
2431        let request = PredictEditsV3Request { input, trigger };
2432
2433        let json_bytes = serde_json::to_vec(&request)?;
2434        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2435
2436        Self::send_api_request(
2437            |builder| {
2438                let req = builder
2439                    .uri(url.as_ref())
2440                    .header("Content-Encoding", "zstd")
2441                    .body(compressed.clone().into());
2442                Ok(req?)
2443            },
2444            client,
2445            llm_token,
2446            organization_id,
2447            app_version,
2448            true,
2449        )
2450        .await
2451    }
2452
2453    async fn send_api_request<Res>(
2454        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2455        client: Arc<Client>,
2456        llm_token: LlmApiToken,
2457        organization_id: Option<OrganizationId>,
2458        app_version: Version,
2459        require_auth: bool,
2460    ) -> Result<(Res, Option<EditPredictionUsage>)>
2461    where
2462        Res: DeserializeOwned,
2463    {
2464        let http_client = client.http_client();
2465
2466        let mut token = if require_auth {
2467            Some(llm_token.acquire(&client, organization_id.clone()).await?)
2468        } else {
2469            llm_token
2470                .acquire(&client, organization_id.clone())
2471                .await
2472                .ok()
2473        };
2474        let mut did_retry = false;
2475
2476        loop {
2477            let request_builder = http_client::Request::builder().method(Method::POST);
2478
2479            let mut request_builder = request_builder
2480                .header("Content-Type", "application/json")
2481                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2482
2483            // Only add Authorization header if we have a token
2484            if let Some(ref token_value) = token {
2485                request_builder =
2486                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2487            }
2488
2489            let request = build(request_builder)?;
2490
2491            let mut response = http_client.send(request).await?;
2492
2493            if let Some(minimum_required_version) = response
2494                .headers()
2495                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2496                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2497            {
2498                anyhow::ensure!(
2499                    app_version >= minimum_required_version,
2500                    ZedUpdateRequiredError {
2501                        minimum_version: minimum_required_version
2502                    }
2503                );
2504            }
2505
2506            if response.status().is_success() {
2507                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2508
2509                let mut body = Vec::new();
2510                response.body_mut().read_to_end(&mut body).await?;
2511                return Ok((serde_json::from_slice(&body)?, usage));
2512            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2513                did_retry = true;
2514                token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
2515            } else {
2516                let mut body = String::new();
2517                response.body_mut().read_to_string(&mut body).await?;
2518                anyhow::bail!(
2519                    "Request failed with status: {:?}\nBody: {}",
2520                    response.status(),
2521                    body
2522                );
2523            }
2524        }
2525    }
2526
2527    pub fn refresh_context(
2528        &mut self,
2529        project: &Entity<Project>,
2530        buffer: &Entity<language::Buffer>,
2531        cursor_position: language::Anchor,
2532        cx: &mut Context<Self>,
2533    ) {
2534        self.get_or_init_project(project, cx)
2535            .context
2536            .update(cx, |store, cx| {
2537                store.refresh(buffer.clone(), cursor_position, cx);
2538            });
2539    }
2540
2541    #[cfg(feature = "cli-support")]
2542    pub fn set_context_for_buffer(
2543        &mut self,
2544        project: &Entity<Project>,
2545        related_files: Vec<RelatedFile>,
2546        cx: &mut Context<Self>,
2547    ) {
2548        self.get_or_init_project(project, cx)
2549            .context
2550            .update(cx, |store, cx| {
2551                store.set_related_files(related_files, cx);
2552            });
2553    }
2554
2555    #[cfg(feature = "cli-support")]
2556    pub fn set_recent_paths_for_project(
2557        &mut self,
2558        project: &Entity<Project>,
2559        paths: impl IntoIterator<Item = project::ProjectPath>,
2560        cx: &mut Context<Self>,
2561    ) {
2562        let project_state = self.get_or_init_project(project, cx);
2563        project_state.recent_paths = paths.into_iter().collect();
2564    }
2565
2566    fn is_file_open_source(
2567        &self,
2568        project: &Entity<Project>,
2569        file: &Arc<dyn File>,
2570        cx: &App,
2571    ) -> bool {
2572        if !file.is_local() || file.is_private() {
2573            return false;
2574        }
2575        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2576            return false;
2577        };
2578        project_state
2579            .license_detection_watchers
2580            .get(&file.worktree_id(cx))
2581            .as_ref()
2582            .is_some_and(|watcher| watcher.is_project_open_source())
2583    }
2584
2585    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2586        self.data_collection_choice.is_enabled(cx)
2587    }
2588
2589    fn load_data_collection_choice() -> DataCollectionChoice {
2590        let choice = KEY_VALUE_STORE
2591            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2592            .log_err()
2593            .flatten();
2594
2595        match choice.as_deref() {
2596            Some("true") => DataCollectionChoice::Enabled,
2597            Some("false") => DataCollectionChoice::Disabled,
2598            Some(_) => {
2599                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2600                DataCollectionChoice::NotAnswered
2601            }
2602            None => DataCollectionChoice::NotAnswered,
2603        }
2604    }
2605
2606    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2607        self.data_collection_choice = self.data_collection_choice.toggle();
2608        let new_choice = self.data_collection_choice;
2609        let is_enabled = new_choice.is_enabled(cx);
2610        db::write_and_log(cx, move || {
2611            KEY_VALUE_STORE.write_kvp(
2612                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2613                is_enabled.to_string(),
2614            )
2615        });
2616    }
2617
2618    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2619        self.shown_predictions.iter()
2620    }
2621
2622    pub fn shown_completions_len(&self) -> usize {
2623        self.shown_predictions.len()
2624    }
2625
2626    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2627        self.rated_predictions.contains(id)
2628    }
2629
2630    pub fn rate_prediction(
2631        &mut self,
2632        prediction: &EditPrediction,
2633        rating: EditPredictionRating,
2634        feedback: String,
2635        cx: &mut Context<Self>,
2636    ) {
2637        let organization = self.user_store.read(cx).current_organization();
2638
2639        self.rated_predictions.insert(prediction.id.clone());
2640
2641        cx.background_spawn({
2642            let client = self.client.clone();
2643            let prediction_id = prediction.id.to_string();
2644            let inputs = serde_json::to_value(&prediction.inputs);
2645            let output = prediction
2646                .edit_preview
2647                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2648            async move {
2649                client
2650                    .cloud_client()
2651                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2652                        organization_id: organization.map(|organization| organization.id.clone()),
2653                        request_id: prediction_id,
2654                        rating: match rating {
2655                            EditPredictionRating::Positive => "positive".to_string(),
2656                            EditPredictionRating::Negative => "negative".to_string(),
2657                        },
2658                        inputs: inputs?,
2659                        output,
2660                        feedback,
2661                    })
2662                    .await?;
2663
2664                anyhow::Ok(())
2665            }
2666        })
2667        .detach_and_log_err(cx);
2668
2669        cx.notify();
2670    }
2671}
2672
2673fn merge_trailing_events_if_needed(
2674    events: &mut VecDeque<StoredEvent>,
2675    end_snapshot: &TextBufferSnapshot,
2676    latest_snapshot: &TextBufferSnapshot,
2677    latest_edit_range: &Range<Anchor>,
2678) {
2679    if let Some(last_event) = events.back() {
2680        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
2681            return;
2682        }
2683    }
2684
2685    let mut next_old_event = None;
2686    let mut mergeable_count = 0;
2687    for old_event in events.iter().rev() {
2688        if let Some(next_old_event) = &next_old_event
2689            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
2690        {
2691            break;
2692        }
2693        mergeable_count += 1;
2694        next_old_event = Some(old_event);
2695    }
2696
2697    if mergeable_count <= 1 {
2698        return;
2699    }
2700
2701    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
2702    let oldest_event = events_to_merge.peek().unwrap();
2703    let oldest_snapshot = oldest_event.old_snapshot.clone();
2704
2705    if let Some((diff, edited_range)) =
2706        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
2707    {
2708        let merged_event = match oldest_event.event.as_ref() {
2709            zeta_prompt::Event::BufferChange {
2710                old_path,
2711                path,
2712                in_open_source_repo,
2713                ..
2714            } => StoredEvent {
2715                event: Arc::new(zeta_prompt::Event::BufferChange {
2716                    old_path: old_path.clone(),
2717                    path: path.clone(),
2718                    diff,
2719                    in_open_source_repo: *in_open_source_repo,
2720                    predicted: events_to_merge.all(|e| {
2721                        matches!(
2722                            e.event.as_ref(),
2723                            zeta_prompt::Event::BufferChange {
2724                                predicted: true,
2725                                ..
2726                            }
2727                        )
2728                    }),
2729                }),
2730                old_snapshot: oldest_snapshot.clone(),
2731                edit_range: end_snapshot.anchor_before(edited_range.start)
2732                    ..end_snapshot.anchor_before(edited_range.end),
2733            },
2734        };
2735        events.truncate(events.len() - mergeable_count);
2736        events.push_back(merged_event);
2737    }
2738}
2739
2740#[derive(Error, Debug)]
2741#[error(
2742    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2743)]
2744pub struct ZedUpdateRequiredError {
2745    minimum_version: Version,
2746}
2747
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Clone, Copy, Debug)]
pub enum DataCollectionChoice {
    /// No choice has been recorded yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2754
2755impl DataCollectionChoice {
2756    pub fn is_enabled(self, cx: &App) -> bool {
2757        if cx.is_staff() {
2758            return true;
2759        }
2760        match self {
2761            Self::Enabled => true,
2762            Self::NotAnswered | Self::Disabled => false,
2763        }
2764    }
2765
2766    #[must_use]
2767    pub fn toggle(&self) -> DataCollectionChoice {
2768        match self {
2769            Self::Enabled => Self::Disabled,
2770            Self::Disabled => Self::Enabled,
2771            Self::NotAnswered => Self::Enabled,
2772        }
2773    }
2774}
2775
2776impl From<bool> for DataCollectionChoice {
2777    fn from(value: bool) -> Self {
2778        match value {
2779            true => DataCollectionChoice::Enabled,
2780            false => DataCollectionChoice::Disabled,
2781        }
2782    }
2783}
2784
2785struct ZedPredictUpsell;
2786
2787impl Dismissable for ZedPredictUpsell {
2788    const KEY: &'static str = "dismissed-edit-predict-upsell";
2789
2790    fn dismissed() -> bool {
2791        // To make this backwards compatible with older versions of Zed, we
2792        // check if the user has seen the previous Edit Prediction Onboarding
2793        // before, by checking the data collection choice which was written to
2794        // the database once the user clicked on "Accept and Enable"
2795        if KEY_VALUE_STORE
2796            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2797            .log_err()
2798            .is_some_and(|s| s.is_some())
2799        {
2800            return true;
2801        }
2802
2803        KEY_VALUE_STORE
2804            .read_kvp(Self::KEY)
2805            .log_err()
2806            .is_some_and(|s| s.is_some())
2807    }
2808}
2809
2810pub fn should_show_upsell_modal() -> bool {
2811    !ZedPredictUpsell::dismissed()
2812}
2813
2814pub fn init(cx: &mut App) {
2815    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2816        workspace.register_action(
2817            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2818                ZedPredictModal::toggle(
2819                    workspace,
2820                    workspace.user_store().clone(),
2821                    workspace.client().clone(),
2822                    window,
2823                    cx,
2824                )
2825            },
2826        );
2827
2828        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2829            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2830                settings
2831                    .project
2832                    .all_languages
2833                    .edit_predictions
2834                    .get_or_insert_default()
2835                    .provider = Some(EditPredictionProvider::None)
2836            });
2837        });
2838        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2839            EditPredictionStore::try_global(cx).and_then(|store| {
2840                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2841            })
2842        }
2843
2844        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2845            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2846                copilot_ui::initiate_sign_in(copilot, window, cx);
2847            }
2848        });
2849        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2850            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2851                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2852            }
2853        });
2854        workspace.register_action(|workspace, _: &SignOut, window, cx| {
2855            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2856                copilot_ui::initiate_sign_out(copilot, window, cx);
2857            }
2858        });
2859    })
2860    .detach();
2861}