edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_api_types::SubmitEditPredictionFeedbackBody;
   5use cloud_llm_client::predict_edits_v3::{
   6    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   7};
   8use cloud_llm_client::{
   9    EditPredictionRejectReason, EditPredictionRejection,
  10    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  11    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  12};
  13use collections::{HashMap, HashSet};
  14use copilot::{Copilot, Reinstall, SignIn, SignOut};
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::{
  19    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  20    channel::mpsc::{self, UnboundedReceiver},
  21    select_biased,
  22};
  23use gpui::BackgroundExecutor;
  24use gpui::http_client::Url;
  25use gpui::{
  26    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  27    http_client::{self, AsyncBody, Method},
  28    prelude::*,
  29};
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{EditPredictionProvider, Settings as _, update_settings_file};
  39use std::collections::{VecDeque, hash_map};
  40use std::env;
  41use text::Edit;
  42use workspace::Workspace;
  43use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  44
  45use std::mem;
  46use std::ops::Range;
  47use std::path::Path;
  48use std::rc::Rc;
  49use std::str::FromStr as _;
  50use std::sync::Arc;
  51use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  52use thiserror::Error;
  53use util::{RangeExt as _, ResultExt as _};
  54use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  55
  56pub mod cursor_excerpt;
  57pub mod example_spec;
  58mod license_detection;
  59pub mod mercury;
  60pub mod ollama;
  61mod onboarding_modal;
  62pub mod open_ai_response;
  63mod prediction;
  64pub mod sweep_ai;
  65
  66pub mod udiff;
  67
  68mod capture_example;
  69mod zed_edit_prediction_delegate;
  70pub mod zeta1;
  71pub mod zeta2;
  72
  73#[cfg(test)]
  74mod edit_prediction_tests;
  75
  76use crate::license_detection::LicenseDetectionWatcher;
  77use crate::mercury::Mercury;
  78use crate::ollama::Ollama;
  79use crate::onboarding_modal::ZedPredictModal;
  80pub use crate::prediction::EditPrediction;
  81pub use crate::prediction::EditPredictionId;
  82use crate::prediction::EditPredictionResult;
  83pub use crate::sweep_ai::SweepAi;
  84pub use capture_example::capture_example;
  85pub use language_model::ApiKeyState;
  86pub use telemetry_events::EditPredictionRating;
  87pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  88
// Declares the gpui actions `ResetOnboarding` and `ClearHistory` in the `edit_prediction`
// namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  98
  99/// Maximum number of events to track.
 100const EVENT_COUNT_MAX: usize = 6;
 101const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 102const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
 103const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 104const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 105
/// Feature flag named `"zeta2"`; always enabled for staff.
/// NOTE(review): presumably gates Zeta2-specific behavior — its checks are not visible here.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}
 115
/// App-global holder for the shared `EditPredictionStore`; installed by
/// `EditPredictionStore::global` and read by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 120
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The `format` is read from `ZED_ZETA_FORMAT` by `zeta2_raw_config_from_env`, where it is
/// called the "version". NOTE(review): it is reportedly also used, lowercased, as the Baseten
/// environment name — that usage is not visible here; confirm.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model id; read from `ZED_ZETA_MODEL` when built from the environment.
    pub model_id: Option<String>,
    /// Prompt format to use when the client constructs the prompt itself.
    pub format: ZetaFormat,
}
 129
/// Central state for edit predictions: per-project tracking, the selected model backend,
/// provider-specific clients, and the channel used to report rejected predictions.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever `RefreshLlmTokenListener` emits (see `new`).
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive for the lifetime of the store.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ProjectState>,
    /// Initialized to `false` in `new`. NOTE(review): presumably set when the server demands a
    /// newer client version — confirm.
    update_required: bool,
    /// Backend used for predictions; `new` initializes this to `EditPredictionModel::Zeta2`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration, if enabled (initially read from the environment).
    zeta2_raw_config: Option<Zeta2RawConfig>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    pub ollama: Ollama,
    data_collection_choice: DataCollectionChoice,
    /// Sender side of the channel drained by the background rejection-reporting task.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    /// NOTE(review): not used in the visible code; presumably predictions shown to the user.
    shown_predictions: VecDeque<EditPrediction>,
    /// NOTE(review): not used in the visible code; presumably ids of already-rated predictions.
    rated_predictions: HashSet<EditPredictionId>,
}
 147
/// Backend used to produce edit predictions.
///
/// Note: the derived `Default` is `Zeta1`, whereas `EditPredictionStore::new` starts out
/// with `Zeta2`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    /// Sweep AI (see the `sweep_ai` module).
    Sweep,
    /// Mercury, shown with the Inception icon (see the `mercury` module).
    Mercury,
    /// Ollama (see the `ollama` module).
    Ollama,
}
 157
/// Inputs gathered for a single prediction request.
/// NOTE(review): presumably handed to the selected model backend — the consumer is not
/// visible in this part of the file.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    /// Buffer the prediction is requested for.
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` taken for this request.
    snapshot: BufferSnapshot,
    /// Position in `buffer` the prediction is requested at.
    position: Anchor,
    /// Recent edit events (presumably the project's edit history).
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    /// NOTE(review): presumably the row range scanned for diagnostics — confirm.
    diagnostic_search_range: Range<Point>,
    /// Optional sink for `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
 172
/// Diagnostic events that may be sent over an optional debug channel (see the `debug_tx`
/// fields) while context is retrieved and predictions are requested.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Context retrieval began for the project with `project_entity_id`.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Context retrieval finished for the project with `project_entity_id`, with key/value
/// metadata describing the result.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// A prediction request started for `buffer` at `position`; `prompt` is presumably the
/// prompt text, when one was built.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// A prediction request finished for `buffer` at `position`; `model_output` is presumably
/// the raw model response, when available.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
 208
/// Maximum number of `UserActionRecord`s kept per project; `ProjectState::record_user_action`
/// evicts the oldest record once this many are stored.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// A single recorded user action within a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    /// NOTE(review): presumably the (zero-based) row of the action — confirm.
    pub line_number: u32,
    /// NOTE(review): presumably the buffer offset of the action — confirm.
    pub offset: usize,
    /// Time of the action in milliseconds since the Unix epoch.
    pub timestamp_epoch_ms: u64,
}

/// Kind of action stored in a `UserActionRecord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 228
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The recorded change.
    pub event: Arc<zeta_prompt::Event>,
    /// Snapshot of the buffer before this event's edits.
    pub old_snapshot: TextBufferSnapshot,
    /// The edited range, anchored in the snapshot taken after the edits.
    pub edit_range: Range<Anchor>,
}
 236
 237impl StoredEvent {
 238    fn can_merge(
 239        &self,
 240        next_old_event: &&&StoredEvent,
 241        new_snapshot: &TextBufferSnapshot,
 242        last_edit_range: &Range<Anchor>,
 243    ) -> bool {
 244        // Events must be for the same buffer
 245        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 246            return false;
 247        }
 248        if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
 249            return false;
 250        }
 251
 252        let a_is_predicted = matches!(
 253            self.event.as_ref(),
 254            zeta_prompt::Event::BufferChange {
 255                predicted: true,
 256                ..
 257            }
 258        );
 259        let b_is_predicted = matches!(
 260            next_old_event.event.as_ref(),
 261            zeta_prompt::Event::BufferChange {
 262                predicted: true,
 263                ..
 264            }
 265        );
 266
 267        // If events come from the same source (both predicted or both manual) then
 268        // we would have coalesced them already.
 269        if a_is_predicted == b_is_predicted {
 270            return false;
 271        }
 272
 273        let left_range = self.edit_range.to_point(new_snapshot);
 274        let right_range = next_old_event.edit_range.to_point(new_snapshot);
 275        let latest_range = last_edit_range.to_point(&new_snapshot);
 276
 277        // Events near to the latest edit are not merged if their sources differ.
 278        if lines_between_ranges(&left_range, &latest_range)
 279            .min(lines_between_ranges(&right_range, &latest_range))
 280            <= CHANGE_GROUPING_LINE_SPAN
 281        {
 282            return false;
 283        }
 284
 285        // Events that are distant from each other are not merged.
 286        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 287            return false;
 288        }
 289
 290        true
 291    }
 292}
 293
 294fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 295    if left.start > right.end {
 296        return left.start.row - right.end.row;
 297    }
 298    if right.start > left.end {
 299        return right.start.row - left.end.row;
 300    }
 301    0
 302}
 303
/// Per-project tracking state held by `EditPredictionStore`.
struct ProjectState {
    /// Finalized edit events (presumably oldest first; `events()` appends `last_event` after).
    events: VecDeque<StoredEvent>,
    /// The in-progress change that has not yet been finalized into `events`.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Registered buffers, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// NOTE(review): presumably the id assigned to the next `PendingPrediction`.
    next_pending_prediction_id: usize,
    /// In-flight prediction requests; capacity is fixed at two.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Optional sink for `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    /// NOTE(review): presumably the buffer and time of the last prediction refresh.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions recorded by `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    /// Store of related excerpts used as prediction context.
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree, consulted to decide whether files are open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history of user actions (see `USER_ACTION_HISTORY_SIZE`).
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot instance created by `EditPredictionStore::start_copilot_for_project`, if any.
    copilot: Option<Entity<Copilot>>,
}
 321
 322impl ProjectState {
 323    fn record_user_action(&mut self, action: UserActionRecord) {
 324        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 325            self.user_actions.pop_front();
 326        }
 327        self.user_actions.push_back(action);
 328    }
 329
 330    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 331        self.events
 332            .iter()
 333            .cloned()
 334            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 335                let (one, two) = event.split_by_pause();
 336                let one = one.finalize(&self.license_detection_watchers, cx);
 337                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 338                one.into_iter().chain(two)
 339            }))
 340            .collect()
 341    }
 342
 343    fn cancel_pending_prediction(
 344        &mut self,
 345        pending_prediction: PendingPrediction,
 346        cx: &mut Context<EditPredictionStore>,
 347    ) {
 348        self.cancelled_predictions.insert(pending_prediction.id);
 349
 350        if pending_prediction.drop_on_cancel {
 351            drop(pending_prediction.task);
 352        } else {
 353            cx.spawn(async move |this, cx| {
 354                let Some(prediction_id) = pending_prediction.task.await else {
 355                    return;
 356                };
 357
 358                this.update(cx, |this, cx| {
 359                    this.reject_prediction(
 360                        prediction_id,
 361                        EditPredictionRejectReason::Canceled,
 362                        false,
 363                        cx,
 364                    );
 365                })
 366                .ok();
 367            })
 368            .detach()
 369        }
 370    }
 371
 372    fn active_buffer(
 373        &self,
 374        project: &Entity<Project>,
 375        cx: &App,
 376    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 377        let project = project.read(cx);
 378        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 379        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 380        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 381        Some((active_buffer, registered_buffer.last_position))
 382    }
 383}
 384
/// The prediction currently held for a project, with how it was requested and displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been shown.
    pub was_shown: bool,
    /// How the prediction was displayed, if it was shown.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
 392
 393impl CurrentEditPrediction {
 394    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 395        let Some(new_edits) = self
 396            .prediction
 397            .interpolate(&self.prediction.buffer.read(cx))
 398        else {
 399            return false;
 400        };
 401
 402        if self.prediction.buffer != old_prediction.prediction.buffer {
 403            return true;
 404        }
 405
 406        let Some(old_edits) = old_prediction
 407            .prediction
 408            .interpolate(&old_prediction.prediction.buffer.read(cx))
 409        else {
 410            return true;
 411        };
 412
 413        let requested_by_buffer_id = self.requested_by.buffer_id();
 414
 415        // This reduces the occurrence of UI thrash from replacing edits
 416        //
 417        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 418        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 419            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 420            && old_edits.len() == 1
 421            && new_edits.len() == 1
 422        {
 423            let (old_range, old_text) = &old_edits[0];
 424            let (new_range, new_text) = &new_edits[0];
 425            new_range == old_range && new_text.starts_with(old_text.as_ref())
 426        } else {
 427            true
 428        }
 429    }
 430}
 431
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// A diagnostics update; not associated with a specific buffer.
    DiagnosticsUpdate,
    /// A request originating from the buffer with this entity id.
    Buffer(EntityId),
}
 437
 438impl PredictionRequestedBy {
 439    pub fn buffer_id(&self) -> Option<EntityId> {
 440        match self {
 441            PredictionRequestedBy::DiagnosticsUpdate => None,
 442            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 443        }
 444    }
 445}
 446
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier added to `ProjectState::cancelled_predictions` when cancelled.
    id: usize,
    /// Resolves to the produced prediction's id, or `None` when there is none to reject.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 455
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// NOTE(review): presumably the prediction edits this buffer directly.
    Local { prediction: &'a EditPrediction },
    /// NOTE(review): presumably the prediction lives elsewhere and is reached by jumping.
    Jump { prediction: &'a EditPrediction },
}
 462
 463#[cfg(test)]
 464impl std::ops::Deref for BufferEditPrediction<'_> {
 465    type Target = EditPrediction;
 466
 467    fn deref(&self) -> &Self::Target {
 468        match self {
 469            BufferEditPrediction::Local { prediction } => prediction,
 470            BufferEditPrediction::Jump { prediction } => prediction,
 471        }
 472    }
 473}
 474
/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    /// The buffer's file, if it has one.
    file: Option<Arc<dyn File>>,
    /// NOTE(review): presumably the snapshot as of the last processed change — confirm.
    snapshot: TextBufferSnapshot,
    /// Last recorded position in this buffer; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
 481
/// The most recent, not-yet-finalized change to a buffer, spanning `old_snapshot` to
/// `new_snapshot`; turned into a `StoredEvent` by `finalize`.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    /// The buffer's file at `old_snapshot`/`new_snapshot`; both paths are recorded in the
    /// finalized event.
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    /// Not read by `finalize`, which recomputes the range from the snapshots.
    edit_range: Option<Range<Anchor>>,
    /// Copied into the `predicted` flag of the finalized `Event::BufferChange`.
    predicted: bool,
    /// Snapshot at the last pause in editing; `split_by_pause` splits the change here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 493
 494impl LastEvent {
 495    pub fn finalize(
 496        &self,
 497        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 498        cx: &App,
 499    ) -> Option<StoredEvent> {
 500        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 501        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 502
 503        let in_open_source_repo =
 504            [self.new_file.as_ref(), self.old_file.as_ref()]
 505                .iter()
 506                .all(|file| {
 507                    file.is_some_and(|file| {
 508                        license_detection_watchers
 509                            .get(&file.worktree_id(cx))
 510                            .is_some_and(|watcher| watcher.is_project_open_source())
 511                    })
 512                });
 513
 514        let (diff, edit_range) =
 515            compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 516
 517        if path == old_path && diff.is_empty() {
 518            None
 519        } else {
 520            Some(StoredEvent {
 521                event: Arc::new(zeta_prompt::Event::BufferChange {
 522                    old_path,
 523                    path,
 524                    diff,
 525                    in_open_source_repo,
 526                    predicted: self.predicted,
 527                }),
 528                edit_range: self.new_snapshot.anchor_before(edit_range.start)
 529                    ..self.new_snapshot.anchor_before(edit_range.end),
 530                old_snapshot: self.old_snapshot.clone(),
 531            })
 532        }
 533    }
 534
 535    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 536        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 537            return (self.clone(), None);
 538        };
 539
 540        let before = LastEvent {
 541            old_snapshot: self.old_snapshot.clone(),
 542            new_snapshot: boundary_snapshot.clone(),
 543            old_file: self.old_file.clone(),
 544            new_file: self.new_file.clone(),
 545            edit_range: None,
 546            predicted: self.predicted,
 547            snapshot_after_last_editing_pause: None,
 548            last_edit_time: self.last_edit_time,
 549        };
 550
 551        let after = LastEvent {
 552            old_snapshot: boundary_snapshot.clone(),
 553            new_snapshot: self.new_snapshot.clone(),
 554            old_file: self.old_file.clone(),
 555            new_file: self.new_file.clone(),
 556            edit_range: None,
 557            predicted: self.predicted,
 558            snapshot_after_last_editing_pause: None,
 559            last_edit_time: self.last_edit_time,
 560        };
 561
 562        (before, Some(after))
 563    }
 564}
 565
 566pub(crate) fn compute_diff_between_snapshots(
 567    old_snapshot: &TextBufferSnapshot,
 568    new_snapshot: &TextBufferSnapshot,
 569) -> Option<(String, Range<Point>)> {
 570    let edits: Vec<Edit<usize>> = new_snapshot
 571        .edits_since::<usize>(&old_snapshot.version)
 572        .collect();
 573
 574    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 575
 576    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 577    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 578    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 579    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 580
 581    const CONTEXT_LINES: u32 = 3;
 582
 583    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 584    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 585    let old_context_end_row =
 586        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 587    let new_context_end_row =
 588        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 589
 590    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 591    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 592    let old_end_line_offset = old_snapshot
 593        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 594    let new_end_line_offset = new_snapshot
 595        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 596    let old_edit_range = old_start_line_offset..old_end_line_offset;
 597    let new_edit_range = new_start_line_offset..new_end_line_offset;
 598
 599    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 600    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 601
 602    let diff = language::unified_diff_with_offsets(
 603        &old_region_text,
 604        &new_region_text,
 605        old_context_start_row,
 606        new_context_start_row,
 607    );
 608
 609    Some((diff, new_start_point..new_end_point))
 610}
 611
 612fn buffer_path_with_id_fallback(
 613    file: Option<&Arc<dyn File>>,
 614    snapshot: &TextBufferSnapshot,
 615    cx: &App,
 616) -> Arc<Path> {
 617    if let Some(file) = file {
 618        file.full_path(cx).into()
 619    } else {
 620        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 621    }
 622}
 623
 624impl EditPredictionStore {
 625    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 626        cx.try_global::<EditPredictionStoreGlobal>()
 627            .map(|global| global.0.clone())
 628    }
 629
 630    pub fn global(
 631        client: &Arc<Client>,
 632        user_store: &Entity<UserStore>,
 633        cx: &mut App,
 634    ) -> Entity<Self> {
 635        cx.try_global::<EditPredictionStoreGlobal>()
 636            .map(|global| global.0.clone())
 637            .unwrap_or_else(|| {
 638                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 639                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 640                ep_store
 641            })
 642    }
 643
 644    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 645        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 646        let data_collection_choice = Self::load_data_collection_choice();
 647
 648        let llm_token = LlmApiToken::default();
 649
 650        let (reject_tx, reject_rx) = mpsc::unbounded();
 651        cx.background_spawn({
 652            let client = client.clone();
 653            let llm_token = llm_token.clone();
 654            let app_version = AppVersion::global(cx);
 655            let background_executor = cx.background_executor().clone();
 656            async move {
 657                Self::handle_rejected_predictions(
 658                    reject_rx,
 659                    client,
 660                    llm_token,
 661                    app_version,
 662                    background_executor,
 663                )
 664                .await
 665            }
 666        })
 667        .detach();
 668
 669        let this = Self {
 670            projects: HashMap::default(),
 671            client,
 672            user_store,
 673            llm_token,
 674            _llm_token_subscription: cx.subscribe(
 675                &refresh_llm_token_listener,
 676                |this, _listener, _event, cx| {
 677                    let client = this.client.clone();
 678                    let llm_token = this.llm_token.clone();
 679                    cx.spawn(async move |_this, _cx| {
 680                        llm_token.refresh(&client).await?;
 681                        anyhow::Ok(())
 682                    })
 683                    .detach_and_log_err(cx);
 684                },
 685            ),
 686            update_required: false,
 687            edit_prediction_model: EditPredictionModel::Zeta2,
 688            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 689            sweep_ai: SweepAi::new(cx),
 690            mercury: Mercury::new(cx),
 691            ollama: Ollama::new(),
 692
 693            data_collection_choice,
 694            reject_predictions_tx: reject_tx,
 695            rated_predictions: Default::default(),
 696            shown_predictions: Default::default(),
 697        };
 698
 699        this
 700    }
 701
 702    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 703        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 704        let format = ZetaFormat::parse(&version_str).ok()?;
 705        let model_id = env::var("ZED_ZETA_MODEL").ok();
 706        Some(Zeta2RawConfig { model_id, format })
 707    }
 708
 709    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 710        self.edit_prediction_model = model;
 711    }
 712
 713    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
 714        self.zeta2_raw_config = Some(config);
 715    }
 716
 717    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
 718        self.zeta2_raw_config.as_ref()
 719    }
 720
 721    pub fn icons(&self) -> edit_prediction_types::EditPredictionIconSet {
 722        use ui::IconName;
 723        match self.edit_prediction_model {
 724            EditPredictionModel::Sweep => {
 725                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 726                    .with_disabled(IconName::SweepAiDisabled)
 727                    .with_up(IconName::SweepAiUp)
 728                    .with_down(IconName::SweepAiDown)
 729                    .with_error(IconName::SweepAiError)
 730            }
 731            EditPredictionModel::Mercury => {
 732                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 733            }
 734            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
 735                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 736                    .with_disabled(IconName::ZedPredictDisabled)
 737                    .with_up(IconName::ZedPredictUp)
 738                    .with_down(IconName::ZedPredictDown)
 739                    .with_error(IconName::ZedPredictError)
 740            }
 741            EditPredictionModel::Ollama => {
 742                edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 743            }
 744        }
 745    }
 746
 747    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 748        self.sweep_ai.api_token.read(cx).has_key()
 749    }
 750
 751    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 752        self.mercury.api_token.read(cx).has_key()
 753    }
 754
 755    pub fn clear_history(&mut self) {
 756        for project_state in self.projects.values_mut() {
 757            project_state.events.clear();
 758            project_state.last_event.take();
 759        }
 760    }
 761
 762    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 763        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 764            project_state.events.clear();
 765            project_state.last_event.take();
 766        }
 767    }
 768
 769    pub fn edit_history_for_project(
 770        &self,
 771        project: &Entity<Project>,
 772        cx: &App,
 773    ) -> Vec<StoredEvent> {
 774        self.projects
 775            .get(&project.entity_id())
 776            .map(|project_state| project_state.events(cx))
 777            .unwrap_or_default()
 778    }
 779
 780    pub fn context_for_project<'a>(
 781        &'a self,
 782        project: &Entity<Project>,
 783        cx: &'a mut App,
 784    ) -> Vec<RelatedFile> {
 785        self.projects
 786            .get(&project.entity_id())
 787            .map(|project_state| {
 788                project_state.context.update(cx, |context, cx| {
 789                    context
 790                        .related_files_with_buffers(cx)
 791                        .map(|(mut related_file, buffer)| {
 792                            related_file.in_open_source_repo = buffer
 793                                .read(cx)
 794                                .file()
 795                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
 796                            related_file
 797                        })
 798                        .collect()
 799                })
 800            })
 801            .unwrap_or_default()
 802    }
 803
 804    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 805        self.projects
 806            .get(&project.entity_id())
 807            .and_then(|project| project.copilot.clone())
 808    }
 809
 810    pub fn start_copilot_for_project(
 811        &mut self,
 812        project: &Entity<Project>,
 813        cx: &mut Context<Self>,
 814    ) -> Option<Entity<Copilot>> {
 815        if DisableAiSettings::get(None, cx).disable_ai {
 816            return None;
 817        }
 818        let state = self.get_or_init_project(project, cx);
 819
 820        if state.copilot.is_some() {
 821            return state.copilot.clone();
 822        }
 823        let _project = project.clone();
 824        let project = project.read(cx);
 825
 826        let node = project.node_runtime().cloned();
 827        if let Some(node) = node {
 828            let next_id = project.languages().next_language_server_id();
 829            let fs = project.fs().clone();
 830
 831            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 832            state.copilot = Some(copilot.clone());
 833            Some(copilot)
 834        } else {
 835            None
 836        }
 837    }
 838
 839    pub fn context_for_project_with_buffers<'a>(
 840        &'a self,
 841        project: &Entity<Project>,
 842        cx: &'a mut App,
 843    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 844        self.projects
 845            .get(&project.entity_id())
 846            .map(|project| {
 847                project.context.update(cx, |context, cx| {
 848                    context.related_files_with_buffers(cx).collect()
 849                })
 850            })
 851            .unwrap_or_default()
 852    }
 853
 854    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 855        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta2) {
 856            self.user_store.read(cx).edit_prediction_usage()
 857        } else {
 858            None
 859        }
 860    }
 861
 862    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 863        self.get_or_init_project(project, cx);
 864    }
 865
 866    pub fn register_buffer(
 867        &mut self,
 868        buffer: &Entity<Buffer>,
 869        project: &Entity<Project>,
 870        cx: &mut Context<Self>,
 871    ) {
 872        let project_state = self.get_or_init_project(project, cx);
 873        Self::register_buffer_impl(project_state, buffer, project, cx);
 874    }
 875
    /// Returns the mutable [`ProjectState`] for `project`, creating it on first access.
    ///
    /// When the state is created, this:
    /// - creates a [`RelatedExcerptStore`] for the project and forwards its events to
    ///   `handle_excerpt_store_event`,
    /// - subscribes to the project's events via `handle_project_event`,
    /// - observes the project's release so its state is removed and observers are notified.
    ///
    /// All other fields start out empty or `None`.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Detached: unlike the project subscriptions below, this one is not kept
                    // in `_subscriptions`.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop all per-project state once the project entity is released.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
 915
 916    pub fn remove_project(&mut self, project: &Entity<Project>) {
 917        self.projects.remove(&project.entity_id());
 918    }
 919
 920    fn handle_excerpt_store_event(
 921        &mut self,
 922        project_entity_id: EntityId,
 923        event: &RelatedExcerptStoreEvent,
 924    ) {
 925        if let Some(project_state) = self.projects.get(&project_entity_id) {
 926            if let Some(debug_tx) = project_state.debug_tx.clone() {
 927                match event {
 928                    RelatedExcerptStoreEvent::StartedRefresh => {
 929                        debug_tx
 930                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 931                                ContextRetrievalStartedDebugEvent {
 932                                    project_entity_id: project_entity_id,
 933                                    timestamp: Instant::now(),
 934                                    search_prompt: String::new(),
 935                                },
 936                            ))
 937                            .ok();
 938                    }
 939                    RelatedExcerptStoreEvent::FinishedRefresh {
 940                        cache_hit_count,
 941                        cache_miss_count,
 942                        mean_definition_latency,
 943                        max_definition_latency,
 944                    } => {
 945                        debug_tx
 946                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 947                                ContextRetrievalFinishedDebugEvent {
 948                                    project_entity_id: project_entity_id,
 949                                    timestamp: Instant::now(),
 950                                    metadata: vec![
 951                                        (
 952                                            "Cache Hits",
 953                                            format!(
 954                                                "{}/{}",
 955                                                cache_hit_count,
 956                                                cache_hit_count + cache_miss_count
 957                                            )
 958                                            .into(),
 959                                        ),
 960                                        (
 961                                            "Max LSP Time",
 962                                            format!("{} ms", max_definition_latency.as_millis())
 963                                                .into(),
 964                                        ),
 965                                        (
 966                                            "Mean LSP Time",
 967                                            format!("{} ms", mean_definition_latency.as_millis())
 968                                                .into(),
 969                                        ),
 970                                    ],
 971                                },
 972                            ))
 973                            .ok();
 974                    }
 975                }
 976            }
 977        }
 978    }
 979
 980    pub fn debug_info(
 981        &mut self,
 982        project: &Entity<Project>,
 983        cx: &mut Context<Self>,
 984    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 985        let project_state = self.get_or_init_project(project, cx);
 986        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 987        project_state.debug_tx = Some(debug_watch_tx);
 988        debug_watch_rx
 989    }
 990
 991    fn handle_project_event(
 992        &mut self,
 993        project: Entity<Project>,
 994        event: &project::Event,
 995        cx: &mut Context<Self>,
 996    ) {
 997        // TODO [zeta2] init with recent paths
 998        match event {
 999            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1000                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1001                    return;
1002                };
1003                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1004                if let Some(path) = path {
1005                    if let Some(ix) = project_state
1006                        .recent_paths
1007                        .iter()
1008                        .position(|probe| probe == &path)
1009                    {
1010                        project_state.recent_paths.remove(ix);
1011                    }
1012                    project_state.recent_paths.push_front(path);
1013                }
1014            }
1015            project::Event::DiagnosticsUpdated { .. } => {
1016                if cx.has_flag::<Zeta2FeatureFlag>() {
1017                    self.refresh_prediction_from_diagnostics(project, cx);
1018                }
1019            }
1020            _ => (),
1021        }
1022    }
1023
    /// Ensures `buffer` is tracked in `project_state` and returns its [`RegisteredBuffer`].
    ///
    /// Every call also makes sure a [`LicenseDetectionWatcher`] exists for the worktree
    /// containing the buffer's file, if the file belongs to a worktree of `project`. The
    /// watcher entry is removed again when that worktree is released.
    ///
    /// On first registration the buffer's current text snapshot and file are recorded,
    /// and two subscriptions are installed:
    /// - `Edited` events call `report_changes_for_buffer` with `is_predicted = false`,
    ///   as long as the project is still alive (it is held weakly).
    /// - Releasing the buffer removes its entry from `registered_buffers`.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                // One watcher per worktree, created lazily.
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Weak handle: the subscription must not keep the project alive.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1089
    /// Records the edits made to `buffer` since its last recorded snapshot.
    ///
    /// Steps:
    /// 1. Registers the buffer if needed, then returns early if its version is unchanged.
    /// 2. Swaps in the new file/snapshot and summarizes the edits since the old version:
    ///    edit count, deleted/inserted lengths (`usize` offsets, presumably bytes), an
    ///    anchor range spanning the first edit's start to the last edit's end, and the
    ///    new end offset of the last edit.
    /// 3. Records a [`UserActionRecord`] classified from those totals, stamped with the
    ///    row of the last edit's end and wall-clock milliseconds since the UNIX epoch.
    /// 4. Either coalesces the change into the pending `last_event` or finalizes that
    ///    event into the `events` history and starts a new one.
    ///
    /// Coalescing requires all of the following:
    /// - the change continues directly from `last_event`'s new snapshot,
    /// - `is_predicted` matches `last_event.predicted`,
    /// - the edit ranges are within `CHANGE_GROUPING_LINE_SPAN` lines of each other.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            // Grow the range from the first edit's start to the latest edit's end.
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Heuristic: inserted/deleted length equal to the edit count means one unit per
        // edit (e.g. typing or backspacing with one or more cursors) -> "Char"; otherwise
        // treat it as a selection-sized change. Mixed replace edits are classified by
        // their inserted length.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // A clock set before the epoch is reported as 0 rather than failing.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                // After a long enough pause, remember the snapshot as it was *before*
                // this edit (i.e. the state at the pause).
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Caps `events` at `EVENT_COUNT_MAX - 1` entries; presumably this leaves
                // room for the pending `last_event` — confirm against consumers.
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1217
1218    fn prediction_at(
1219        &mut self,
1220        buffer: &Entity<Buffer>,
1221        position: Option<language::Anchor>,
1222        project: &Entity<Project>,
1223        cx: &App,
1224    ) -> Option<BufferEditPrediction<'_>> {
1225        let project_state = self.projects.get_mut(&project.entity_id())?;
1226        if let Some(position) = position
1227            && let Some(buffer) = project_state
1228                .registered_buffers
1229                .get_mut(&buffer.entity_id())
1230        {
1231            buffer.last_position = Some(position);
1232        }
1233
1234        let CurrentEditPrediction {
1235            requested_by,
1236            prediction,
1237            ..
1238        } = project_state.current_prediction.as_ref()?;
1239
1240        if prediction.targets_buffer(buffer.read(cx)) {
1241            Some(BufferEditPrediction::Local { prediction })
1242        } else {
1243            let show_jump = match requested_by {
1244                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1245                    requested_by_buffer_id == &buffer.entity_id()
1246                }
1247                PredictionRequestedBy::DiagnosticsUpdate => true,
1248            };
1249
1250            if show_jump {
1251                Some(BufferEditPrediction::Jump { prediction })
1252            } else {
1253                None
1254            }
1255        }
1256    }
1257
1258    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1259        let Some(current_prediction) = self
1260            .projects
1261            .get_mut(&project.entity_id())
1262            .and_then(|project_state| project_state.current_prediction.take())
1263        else {
1264            return;
1265        };
1266
1267        self.report_changes_for_buffer(&current_prediction.prediction.buffer, project, true, cx);
1268
1269        // can't hold &mut project_state ref across report_changes_for_buffer_call
1270        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1271            return;
1272        };
1273
1274        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1275            project_state.cancel_pending_prediction(pending_prediction, cx);
1276        }
1277
1278        match self.edit_prediction_model {
1279            EditPredictionModel::Sweep => {
1280                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1281            }
1282            EditPredictionModel::Mercury => {
1283                mercury::edit_prediction_accepted(
1284                    current_prediction.prediction.id,
1285                    self.client.http_client(),
1286                    cx,
1287                );
1288            }
1289            EditPredictionModel::Ollama => {}
1290            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1291                zeta2::edit_prediction_accepted(self, current_prediction, cx)
1292            }
1293        }
1294    }
1295
    /// Long-running loop that batches prediction rejections from `rx` and sends them to
    /// the `/predict_edits/reject` endpoint. It runs until `rx` is closed.
    ///
    /// Batching: while the batch holds fewer than half of
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` items, the loop waits for either
    /// another rejection (preferred, via `select_biased!`) or `REJECT_REQUEST_DEBOUNCE`
    /// to elapse before flushing. A closed channel also triggers a flush.
    ///
    /// Each request sends at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the
    /// most recent items. Items are removed from the batch only when the request
    /// succeeds; failures are logged and the items stay queued for the next flush.
    ///
    /// NOTE(review): under persistent failures `batched` grows without bound, and older
    /// items are only sent after newer ones succeed. The URL `.unwrap()` panics this
    /// task if the URL cannot be built.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    // Another rejection is already available: loop around and collect it.
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
                true,
            )
            .await;

            // Only forget the items that were actually delivered.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1354
1355    fn reject_current_prediction(
1356        &mut self,
1357        reason: EditPredictionRejectReason,
1358        project: &Entity<Project>,
1359        cx: &App,
1360    ) {
1361        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1362            project_state.pending_predictions.clear();
1363            if let Some(prediction) = project_state.current_prediction.take() {
1364                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
1365            }
1366        };
1367    }
1368
1369    fn did_show_current_prediction(
1370        &mut self,
1371        project: &Entity<Project>,
1372        display_type: edit_prediction_types::SuggestionDisplayType,
1373        cx: &mut Context<Self>,
1374    ) {
1375        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1376            return;
1377        };
1378
1379        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1380            return;
1381        };
1382
1383        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1384        let previous_shown_with = current_prediction.shown_with;
1385
1386        if previous_shown_with.is_none() || !is_jump {
1387            current_prediction.shown_with = Some(display_type);
1388        }
1389
1390        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1391
1392        if is_first_non_jump_show {
1393            current_prediction.was_shown = true;
1394        }
1395
1396        let display_type_changed = previous_shown_with != Some(display_type);
1397
1398        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1399            sweep_ai::edit_prediction_shown(
1400                &self.sweep_ai,
1401                self.client.clone(),
1402                &current_prediction.prediction,
1403                display_type,
1404                cx,
1405            );
1406        }
1407
1408        if is_first_non_jump_show {
1409            self.shown_predictions
1410                .push_front(current_prediction.prediction.clone());
1411            if self.shown_predictions.len() > 50 {
1412                let completion = self.shown_predictions.pop_back().unwrap();
1413                self.rated_predictions.remove(&completion.id);
1414            }
1415        }
1416    }
1417
1418    fn reject_prediction(
1419        &mut self,
1420        prediction_id: EditPredictionId,
1421        reason: EditPredictionRejectReason,
1422        was_shown: bool,
1423        cx: &App,
1424    ) {
1425        match self.edit_prediction_model {
1426            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1427                self.reject_predictions_tx
1428                    .unbounded_send(EditPredictionRejection {
1429                        request_id: prediction_id.to_string(),
1430                        reason,
1431                        was_shown,
1432                    })
1433                    .log_err();
1434            }
1435            EditPredictionModel::Sweep | EditPredictionModel::Ollama => {}
1436            EditPredictionModel::Mercury => {
1437                mercury::edit_prediction_rejected(
1438                    prediction_id,
1439                    was_shown,
1440                    reason,
1441                    self.client.http_client(),
1442                    cx,
1443                );
1444            }
1445        }
1446    }
1447
1448    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1449        self.projects
1450            .get(&project.entity_id())
1451            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1452    }
1453
1454    pub fn refresh_prediction_from_buffer(
1455        &mut self,
1456        project: Entity<Project>,
1457        buffer: Entity<Buffer>,
1458        position: language::Anchor,
1459        cx: &mut Context<Self>,
1460    ) {
1461        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1462            let Some(request_task) = this
1463                .update(cx, |this, cx| {
1464                    this.request_prediction(
1465                        &project,
1466                        &buffer,
1467                        position,
1468                        PredictEditsRequestTrigger::Other,
1469                        cx,
1470                    )
1471                })
1472                .log_err()
1473            else {
1474                return Task::ready(anyhow::Ok(None));
1475            };
1476
1477            cx.spawn(async move |_cx| {
1478                request_task.await.map(|prediction_result| {
1479                    prediction_result.map(|prediction_result| {
1480                        (
1481                            prediction_result,
1482                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1483                        )
1484                    })
1485                })
1486            })
1487        })
1488    }
1489
1490    pub fn refresh_prediction_from_diagnostics(
1491        &mut self,
1492        project: Entity<Project>,
1493        cx: &mut Context<Self>,
1494    ) {
1495        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1496            return;
1497        };
1498
1499        // Prefer predictions from buffer
1500        if project_state.current_prediction.is_some() {
1501            return;
1502        };
1503
1504        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1505            let Some((active_buffer, snapshot, cursor_point)) = this
1506                .read_with(cx, |this, cx| {
1507                    let project_state = this.projects.get(&project.entity_id())?;
1508                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
1509                    let snapshot = buffer.read(cx).snapshot();
1510
1511                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
1512                        return None;
1513                    }
1514
1515                    let cursor_point = position
1516                        .map(|pos| pos.to_point(&snapshot))
1517                        .unwrap_or_default();
1518
1519                    Some((buffer, snapshot, cursor_point))
1520                })
1521                .log_err()
1522                .flatten()
1523            else {
1524                return Task::ready(anyhow::Ok(None));
1525            };
1526
1527            cx.spawn(async move |cx| {
1528                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1529                    active_buffer,
1530                    &snapshot,
1531                    Default::default(),
1532                    cursor_point,
1533                    &project,
1534                    cx,
1535                )
1536                .await?
1537                else {
1538                    return anyhow::Ok(None);
1539                };
1540
1541                let Some(prediction_result) = this
1542                    .update(cx, |this, cx| {
1543                        this.request_prediction(
1544                            &project,
1545                            &jump_buffer,
1546                            jump_position,
1547                            PredictEditsRequestTrigger::Diagnostics,
1548                            cx,
1549                        )
1550                    })?
1551                    .await?
1552                else {
1553                    return anyhow::Ok(None);
1554                };
1555
1556                this.update(cx, |this, cx| {
1557                    Some((
1558                        if this
1559                            .get_or_init_project(&project, cx)
1560                            .current_prediction
1561                            .is_none()
1562                        {
1563                            prediction_result
1564                        } else {
1565                            EditPredictionResult {
1566                                id: prediction_result.id,
1567                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1568                            }
1569                        },
1570                        PredictionRequestedBy::DiagnosticsUpdate,
1571                    ))
1572                })
1573            })
1574        });
1575    }
1576
1577    fn predictions_enabled_at(
1578        snapshot: &BufferSnapshot,
1579        position: Option<language::Anchor>,
1580        cx: &App,
1581    ) -> bool {
1582        let file = snapshot.file();
1583        let all_settings = all_language_settings(file, cx);
1584        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1585            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1586        {
1587            return false;
1588        }
1589
1590        if let Some(last_position) = position {
1591            let settings = snapshot.settings_at(last_position, cx);
1592
1593            if !settings.edit_predictions_disabled_in.is_empty()
1594                && let Some(scope) = snapshot.language_scope_at(last_position)
1595                && let Some(scope_name) = scope.override_name()
1596                && settings
1597                    .edit_predictions_disabled_in
1598                    .iter()
1599                    .any(|s| s == scope_name)
1600            {
1601                return false;
1602            }
1603        }
1604
1605        true
1606    }
1607
    /// Minimum delay between two prediction refreshes queued for the same throttle
    /// entity (a buffer, or the project for diagnostics-driven refreshes).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Test builds apply no throttling delay (presumably so tests need not advance timers).
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1612
    /// Enqueues a throttled prediction refresh for `project`.
    ///
    /// `do_refresh` performs the actual request. The spawned task:
    /// 1. If the previous refresh was for the same `throttle_entity`, waits until
    ///    [`Self::THROTTLE_TIMEOUT`] has elapsed since that refresh.
    /// 2. Returns without calling `do_refresh` if this prediction's id was found in
    ///    `cancelled_predictions` (or the store was dropped). Otherwise it records
    ///    the refresh time used for future throttling.
    /// 3. Calls `do_refresh`. Errors are logged and treated as "no prediction".
    /// 4. Unless the id was cancelled in the meantime, either installs the result as
    ///    the current prediction or reports it as rejected (`Replaced` for the old
    ///    current prediction, `CurrentPreferred` for the new one, or the reason
    ///    carried by an `Err` result).
    /// 5. Removes itself from the pending list and hands every pending prediction
    ///    enqueued before it to `cancel_pending_prediction`.
    ///
    /// The task resolves to the id of the prediction returned by `do_refresh`, if any.
    ///
    /// At most two predictions are kept pending (one for Ollama). When the limit is
    /// reached, the most recently enqueued pending prediction is evicted and cancelled,
    /// and this one is pushed in its place. Ollama entries are marked `drop_on_cancel`.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        let is_ollama = self.edit_prediction_model == EditPredictionModel::Ollama;
        let drop_on_cancel = is_ollama;
        // Ollama allows a single pending prediction; other models allow two.
        let max_pending_predictions = if is_ollama { 1 } else { 2 };
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        // Captured at enqueue time; the throttle delay is measured from this refresh.
        let last_request = project_state.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Throttle only against the previous refresh of the same entity.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            // Starts as `true` so that a failed `update` (store dropped) also skips
            // the request.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                // A prediction cancelled while its request was running is discarded.
                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                                shown_with: None,
                            };

                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                        cx,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                        cx,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false, cx);
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-borrow the project state: `this` was used mutably above to reject
                // predictions.
                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Take the list out so entries can be cancelled through `project_state`
                // while iterating, then put the remainder back.
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        // Everything before `ix` was enqueued earlier and is superseded.
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        if project_state.pending_predictions.len() < max_pending_predictions {
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
        } else {
            // At capacity (len >= max_pending_predictions >= 1, so `pop` cannot fail):
            // evict the most recently enqueued pending prediction and cancel it.
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
                drop_on_cancel,
            });
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1760
1761    pub fn request_prediction(
1762        &mut self,
1763        project: &Entity<Project>,
1764        active_buffer: &Entity<Buffer>,
1765        position: language::Anchor,
1766        trigger: PredictEditsRequestTrigger,
1767        cx: &mut Context<Self>,
1768    ) -> Task<Result<Option<EditPredictionResult>>> {
1769        self.request_prediction_internal(
1770            project.clone(),
1771            active_buffer.clone(),
1772            position,
1773            trigger,
1774            cx.has_flag::<Zeta2FeatureFlag>(),
1775            cx,
1776        )
1777    }
1778
1779    fn request_prediction_internal(
1780        &mut self,
1781        project: Entity<Project>,
1782        active_buffer: Entity<Buffer>,
1783        position: language::Anchor,
1784        trigger: PredictEditsRequestTrigger,
1785        allow_jump: bool,
1786        cx: &mut Context<Self>,
1787    ) -> Task<Result<Option<EditPredictionResult>>> {
1788        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1789
1790        self.get_or_init_project(&project, cx);
1791        let project_state = self.projects.get(&project.entity_id()).unwrap();
1792        let stored_events = project_state.events(cx);
1793        let has_events = !stored_events.is_empty();
1794        let events: Vec<Arc<zeta_prompt::Event>> =
1795            stored_events.into_iter().map(|e| e.event).collect();
1796        let debug_tx = project_state.debug_tx.clone();
1797
1798        let snapshot = active_buffer.read(cx).snapshot();
1799        let cursor_point = position.to_point(&snapshot);
1800        let current_offset = position.to_offset(&snapshot);
1801
1802        let mut user_actions: Vec<UserActionRecord> =
1803            project_state.user_actions.iter().cloned().collect();
1804
1805        if let Some(last_action) = user_actions.last() {
1806            if last_action.buffer_id == active_buffer.entity_id()
1807                && current_offset != last_action.offset
1808            {
1809                let timestamp_epoch_ms = SystemTime::now()
1810                    .duration_since(UNIX_EPOCH)
1811                    .map(|d| d.as_millis() as u64)
1812                    .unwrap_or(0);
1813                user_actions.push(UserActionRecord {
1814                    action_type: UserActionType::CursorMovement,
1815                    buffer_id: active_buffer.entity_id(),
1816                    line_number: cursor_point.row,
1817                    offset: current_offset,
1818                    timestamp_epoch_ms,
1819                });
1820            }
1821        }
1822        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1823        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1824        let diagnostic_search_range =
1825            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1826
1827        let related_files = self.context_for_project(&project, cx);
1828
1829        let inputs = EditPredictionModelInput {
1830            project: project.clone(),
1831            buffer: active_buffer.clone(),
1832            snapshot: snapshot.clone(),
1833            position,
1834            events,
1835            related_files,
1836            recent_paths: project_state.recent_paths.clone(),
1837            trigger,
1838            diagnostic_search_range: diagnostic_search_range.clone(),
1839            debug_tx,
1840            user_actions,
1841        };
1842
1843        let task = match &self.edit_prediction_model {
1844            EditPredictionModel::Zeta1 => zeta2::request_prediction_with_zeta2(
1845                self,
1846                inputs,
1847                Some(zeta_prompt::EditPredictionModelKind::Zeta1),
1848                cx,
1849            ),
1850            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
1851                self,
1852                inputs,
1853                Some(zeta_prompt::EditPredictionModelKind::Zeta2),
1854                cx,
1855            ),
1856            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1857            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1858            EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx),
1859        };
1860
1861        cx.spawn(async move |this, cx| {
1862            let prediction = task.await?;
1863
1864            if prediction.is_none() && allow_jump {
1865                let cursor_point = position.to_point(&snapshot);
1866                if has_events
1867                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1868                        active_buffer.clone(),
1869                        &snapshot,
1870                        diagnostic_search_range,
1871                        cursor_point,
1872                        &project,
1873                        cx,
1874                    )
1875                    .await?
1876                {
1877                    return this
1878                        .update(cx, |this, cx| {
1879                            this.request_prediction_internal(
1880                                project,
1881                                jump_buffer,
1882                                jump_position,
1883                                trigger,
1884                                false,
1885                                cx,
1886                            )
1887                        })?
1888                        .await;
1889                }
1890
1891                return anyhow::Ok(None);
1892            }
1893
1894            Ok(prediction)
1895        })
1896    }
1897
1898    async fn next_diagnostic_location(
1899        active_buffer: Entity<Buffer>,
1900        active_buffer_snapshot: &BufferSnapshot,
1901        active_buffer_diagnostic_search_range: Range<Point>,
1902        active_buffer_cursor_point: Point,
1903        project: &Entity<Project>,
1904        cx: &mut AsyncApp,
1905    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1906        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1907        let mut jump_location = active_buffer_snapshot
1908            .diagnostic_groups(None)
1909            .into_iter()
1910            .filter_map(|(_, group)| {
1911                let range = &group.entries[group.primary_ix]
1912                    .range
1913                    .to_point(&active_buffer_snapshot);
1914                if range.overlaps(&active_buffer_diagnostic_search_range) {
1915                    None
1916                } else {
1917                    Some(range.start)
1918                }
1919            })
1920            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1921            .map(|position| {
1922                (
1923                    active_buffer.clone(),
1924                    active_buffer_snapshot.anchor_before(position),
1925                )
1926            });
1927
1928        if jump_location.is_none() {
1929            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1930                let file = buffer.file()?;
1931
1932                Some(ProjectPath {
1933                    worktree_id: file.worktree_id(cx),
1934                    path: file.path().clone(),
1935                })
1936            });
1937
1938            let buffer_task = project.update(cx, |project, cx| {
1939                let (path, _, _) = project
1940                    .diagnostic_summaries(false, cx)
1941                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1942                    .max_by_key(|(path, _, _)| {
1943                        // find the buffer with errors that shares most parent directories
1944                        path.path
1945                            .components()
1946                            .zip(
1947                                active_buffer_path
1948                                    .as_ref()
1949                                    .map(|p| p.path.components())
1950                                    .unwrap_or_default(),
1951                            )
1952                            .take_while(|(a, b)| a == b)
1953                            .count()
1954                    })?;
1955
1956                Some(project.open_buffer(path, cx))
1957            });
1958
1959            if let Some(buffer_task) = buffer_task {
1960                let closest_buffer = buffer_task.await?;
1961
1962                jump_location = closest_buffer
1963                    .read_with(cx, |buffer, _cx| {
1964                        buffer
1965                            .buffer_diagnostics(None)
1966                            .into_iter()
1967                            .min_by_key(|entry| entry.diagnostic.severity)
1968                            .map(|entry| entry.range.start)
1969                    })
1970                    .map(|position| (closest_buffer, position));
1971            }
1972        }
1973
1974        anyhow::Ok(jump_location)
1975    }
1976
1977    async fn send_raw_llm_request(
1978        request: RawCompletionRequest,
1979        client: Arc<Client>,
1980        custom_url: Option<Arc<Url>>,
1981        llm_token: LlmApiToken,
1982        app_version: Version,
1983    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1984        let url = if let Some(custom_url) = custom_url {
1985            custom_url.as_ref().clone()
1986        } else {
1987            client
1988                .http_client()
1989                .build_zed_llm_url("/predict_edits/raw", &[])?
1990        };
1991
1992        Self::send_api_request(
1993            |builder| {
1994                let req = builder
1995                    .uri(url.as_ref())
1996                    .body(serde_json::to_string(&request)?.into());
1997                Ok(req?)
1998            },
1999            client,
2000            llm_token,
2001            app_version,
2002            true,
2003        )
2004        .await
2005    }
2006
2007    pub(crate) async fn send_v3_request(
2008        input: ZetaPromptInput,
2009        client: Arc<Client>,
2010        llm_token: LlmApiToken,
2011        app_version: Version,
2012        trigger: PredictEditsRequestTrigger,
2013    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2014        let url = client
2015            .http_client()
2016            .build_zed_llm_url("/predict_edits/v3", &[])?;
2017
2018        let request = PredictEditsV3Request { input, trigger };
2019
2020        let json_bytes = serde_json::to_vec(&request)?;
2021        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2022
2023        Self::send_api_request(
2024            |builder| {
2025                let req = builder
2026                    .uri(url.as_ref())
2027                    .header("Content-Encoding", "zstd")
2028                    .body(compressed.clone().into());
2029                Ok(req?)
2030            },
2031            client,
2032            llm_token,
2033            app_version,
2034            true,
2035        )
2036        .await
2037    }
2038
2039    fn handle_api_response<T>(
2040        this: &WeakEntity<Self>,
2041        response: Result<(T, Option<EditPredictionUsage>)>,
2042        cx: &mut gpui::AsyncApp,
2043    ) -> Result<T> {
2044        match response {
2045            Ok((data, usage)) => {
2046                if let Some(usage) = usage {
2047                    this.update(cx, |this, cx| {
2048                        this.user_store.update(cx, |user_store, cx| {
2049                            user_store.update_edit_prediction_usage(usage, cx);
2050                        });
2051                    })
2052                    .ok();
2053                }
2054                Ok(data)
2055            }
2056            Err(err) => {
2057                if err.is::<ZedUpdateRequiredError>() {
2058                    cx.update(|cx| {
2059                        this.update(cx, |this, _cx| {
2060                            this.update_required = true;
2061                        })
2062                        .ok();
2063
2064                        let error_message: SharedString = err.to_string().into();
2065                        show_app_notification(
2066                            NotificationId::unique::<ZedUpdateRequiredError>(),
2067                            cx,
2068                            move |cx| {
2069                                cx.new(|cx| {
2070                                    ErrorMessagePrompt::new(error_message.clone(), cx)
2071                                        .with_link_button("Update Zed", "https://zed.dev/releases")
2072                                })
2073                            },
2074                        );
2075                    });
2076                }
2077                Err(err)
2078            }
2079        }
2080    }
2081
2082    async fn send_api_request<Res>(
2083        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2084        client: Arc<Client>,
2085        llm_token: LlmApiToken,
2086        app_version: Version,
2087        require_auth: bool,
2088    ) -> Result<(Res, Option<EditPredictionUsage>)>
2089    where
2090        Res: DeserializeOwned,
2091    {
2092        let http_client = client.http_client();
2093
2094        let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
2095            Some(custom_token)
2096        } else if require_auth {
2097            Some(llm_token.acquire(&client).await?)
2098        } else {
2099            llm_token.acquire(&client).await.ok()
2100        };
2101        let mut did_retry = false;
2102
2103        loop {
2104            let request_builder = http_client::Request::builder().method(Method::POST);
2105
2106            let mut request_builder = request_builder
2107                .header("Content-Type", "application/json")
2108                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2109
2110            // Only add Authorization header if we have a token
2111            if let Some(ref token_value) = token {
2112                request_builder =
2113                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2114            }
2115
2116            let request = build(request_builder)?;
2117
2118            let mut response = http_client.send(request).await?;
2119
2120            if let Some(minimum_required_version) = response
2121                .headers()
2122                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2123                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2124            {
2125                anyhow::ensure!(
2126                    app_version >= minimum_required_version,
2127                    ZedUpdateRequiredError {
2128                        minimum_version: minimum_required_version
2129                    }
2130                );
2131            }
2132
2133            if response.status().is_success() {
2134                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2135
2136                let mut body = Vec::new();
2137                response.body_mut().read_to_end(&mut body).await?;
2138                return Ok((serde_json::from_slice(&body)?, usage));
2139            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2140                did_retry = true;
2141                token = Some(llm_token.refresh(&client).await?);
2142            } else {
2143                let mut body = String::new();
2144                response.body_mut().read_to_string(&mut body).await?;
2145                anyhow::bail!(
2146                    "Request failed with status: {:?}\nBody: {}",
2147                    response.status(),
2148                    body
2149                );
2150            }
2151        }
2152    }
2153
2154    pub fn refresh_context(
2155        &mut self,
2156        project: &Entity<Project>,
2157        buffer: &Entity<language::Buffer>,
2158        cursor_position: language::Anchor,
2159        cx: &mut Context<Self>,
2160    ) {
2161        self.get_or_init_project(project, cx)
2162            .context
2163            .update(cx, |store, cx| {
2164                store.refresh(buffer.clone(), cursor_position, cx);
2165            });
2166    }
2167
2168    #[cfg(feature = "cli-support")]
2169    pub fn set_context_for_buffer(
2170        &mut self,
2171        project: &Entity<Project>,
2172        related_files: Vec<RelatedFile>,
2173        cx: &mut Context<Self>,
2174    ) {
2175        self.get_or_init_project(project, cx)
2176            .context
2177            .update(cx, |store, cx| {
2178                store.set_related_files(related_files, cx);
2179            });
2180    }
2181
2182    #[cfg(feature = "cli-support")]
2183    pub fn set_recent_paths_for_project(
2184        &mut self,
2185        project: &Entity<Project>,
2186        paths: impl IntoIterator<Item = project::ProjectPath>,
2187        cx: &mut Context<Self>,
2188    ) {
2189        let project_state = self.get_or_init_project(project, cx);
2190        project_state.recent_paths = paths.into_iter().collect();
2191    }
2192
2193    fn is_file_open_source(
2194        &self,
2195        project: &Entity<Project>,
2196        file: &Arc<dyn File>,
2197        cx: &App,
2198    ) -> bool {
2199        if !file.is_local() || file.is_private() {
2200            return false;
2201        }
2202        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2203            return false;
2204        };
2205        project_state
2206            .license_detection_watchers
2207            .get(&file.worktree_id(cx))
2208            .as_ref()
2209            .is_some_and(|watcher| watcher.is_project_open_source())
2210    }
2211
2212    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2213        self.data_collection_choice.is_enabled(cx)
2214    }
2215
2216    fn load_data_collection_choice() -> DataCollectionChoice {
2217        let choice = KEY_VALUE_STORE
2218            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2219            .log_err()
2220            .flatten();
2221
2222        match choice.as_deref() {
2223            Some("true") => DataCollectionChoice::Enabled,
2224            Some("false") => DataCollectionChoice::Disabled,
2225            Some(_) => {
2226                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2227                DataCollectionChoice::NotAnswered
2228            }
2229            None => DataCollectionChoice::NotAnswered,
2230        }
2231    }
2232
2233    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2234        self.data_collection_choice = self.data_collection_choice.toggle();
2235        let new_choice = self.data_collection_choice;
2236        let is_enabled = new_choice.is_enabled(cx);
2237        db::write_and_log(cx, move || {
2238            KEY_VALUE_STORE.write_kvp(
2239                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2240                is_enabled.to_string(),
2241            )
2242        });
2243    }
2244
2245    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2246        self.shown_predictions.iter()
2247    }
2248
2249    pub fn shown_completions_len(&self) -> usize {
2250        self.shown_predictions.len()
2251    }
2252
2253    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2254        self.rated_predictions.contains(id)
2255    }
2256
2257    pub fn rate_prediction(
2258        &mut self,
2259        prediction: &EditPrediction,
2260        rating: EditPredictionRating,
2261        feedback: String,
2262        cx: &mut Context<Self>,
2263    ) {
2264        let organization = self.user_store.read(cx).current_organization();
2265
2266        self.rated_predictions.insert(prediction.id.clone());
2267
2268        cx.background_spawn({
2269            let client = self.client.clone();
2270            let prediction_id = prediction.id.to_string();
2271            let inputs = serde_json::to_value(&prediction.inputs);
2272            let output = prediction
2273                .edit_preview
2274                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2275            async move {
2276                client
2277                    .cloud_client()
2278                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2279                        organization_id: organization.map(|organization| organization.id.clone()),
2280                        request_id: prediction_id,
2281                        rating: match rating {
2282                            EditPredictionRating::Positive => "positive".to_string(),
2283                            EditPredictionRating::Negative => "negative".to_string(),
2284                        },
2285                        inputs: inputs?,
2286                        output,
2287                        feedback,
2288                    })
2289                    .await?;
2290
2291                anyhow::Ok(())
2292            }
2293        })
2294        .detach_and_log_err(cx);
2295
2296        cx.notify();
2297    }
2298}
2299
/// Collapses the trailing run of mergeable events in `events` into a single
/// `BufferChange` event. Its diff spans from the oldest merged event's
/// `old_snapshot` to `end_snapshot`.
///
/// `events` is left unchanged when any of these holds:
/// - the newest event's `old_snapshot.remote_id()` differs from `latest_snapshot`'s;
/// - fewer than two trailing events are mergeable, as decided by
///   `StoredEvent::can_merge`;
/// - `compute_diff_between_snapshots` returns `None`.
///
/// The merged event keeps the oldest event's paths and `in_open_source_repo` flag.
/// It is marked `predicted` only if every merged event was. Its `edit_range` is
/// anchored in `end_snapshot`.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
    }

    // Walk from the newest event backwards, counting how many consecutive events
    // can merge with the (newer) event that follows them.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    if mergeable_count <= 1 {
        return;
    }

    // `peek` leaves the oldest event in the iterator, so the `all` check below also
    // covers it.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        // Replace the merged tail with the single combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2366
2367pub(crate) fn filter_redundant_excerpts(
2368    mut related_files: Vec<RelatedFile>,
2369    cursor_path: &Path,
2370    cursor_row_range: Range<u32>,
2371) -> Vec<RelatedFile> {
2372    for file in &mut related_files {
2373        if file.path.as_ref() == cursor_path {
2374            file.excerpts.retain(|excerpt| {
2375                excerpt.row_range.start < cursor_row_range.start
2376                    || excerpt.row_range.end > cursor_row_range.end
2377            });
2378        }
2379    }
2380    related_files.retain(|file| !file.excerpts.is_empty());
2381    related_files
2382}
2383
/// Error raised when a server response carries a minimum-required-version header
/// newer than the running Zed version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version advertised by the server.
    minimum_version: Version,
}
2391
/// The stored user choice about edit prediction data collection.
#[derive(Clone, Copy, Debug)]
pub enum DataCollectionChoice {
    /// No choice has been stored yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2398
2399impl DataCollectionChoice {
2400    pub fn is_enabled(self, cx: &App) -> bool {
2401        if cx.is_staff() {
2402            return true;
2403        }
2404        match self {
2405            Self::Enabled => true,
2406            Self::NotAnswered | Self::Disabled => false,
2407        }
2408    }
2409
2410    #[must_use]
2411    pub fn toggle(&self) -> DataCollectionChoice {
2412        match self {
2413            Self::Enabled => Self::Disabled,
2414            Self::Disabled => Self::Enabled,
2415            Self::NotAnswered => Self::Enabled,
2416        }
2417    }
2418}
2419
2420impl From<bool> for DataCollectionChoice {
2421    fn from(value: bool) -> Self {
2422        match value {
2423            true => DataCollectionChoice::Enabled,
2424            false => DataCollectionChoice::Disabled,
2425        }
2426    }
2427}
2428
2429struct ZedPredictUpsell;
2430
2431impl Dismissable for ZedPredictUpsell {
2432    const KEY: &'static str = "dismissed-edit-predict-upsell";
2433
2434    fn dismissed() -> bool {
2435        // To make this backwards compatible with older versions of Zed, we
2436        // check if the user has seen the previous Edit Prediction Onboarding
2437        // before, by checking the data collection choice which was written to
2438        // the database once the user clicked on "Accept and Enable"
2439        if KEY_VALUE_STORE
2440            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2441            .log_err()
2442            .is_some_and(|s| s.is_some())
2443        {
2444            return true;
2445        }
2446
2447        KEY_VALUE_STORE
2448            .read_kvp(Self::KEY)
2449            .log_err()
2450            .is_some_and(|s| s.is_some())
2451    }
2452}
2453
2454pub fn should_show_upsell_modal() -> bool {
2455    !ZedPredictUpsell::dismissed()
2456}
2457
2458pub fn init(cx: &mut App) {
2459    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2460        workspace.register_action(
2461            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2462                ZedPredictModal::toggle(
2463                    workspace,
2464                    workspace.user_store().clone(),
2465                    workspace.client().clone(),
2466                    window,
2467                    cx,
2468                )
2469            },
2470        );
2471
2472        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2473            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2474                settings
2475                    .project
2476                    .all_languages
2477                    .edit_predictions
2478                    .get_or_insert_default()
2479                    .provider = Some(EditPredictionProvider::None)
2480            });
2481        });
2482        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2483            EditPredictionStore::try_global(cx).and_then(|store| {
2484                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2485            })
2486        }
2487
2488        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2489            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2490                copilot_ui::initiate_sign_in(copilot, window, cx);
2491            }
2492        });
2493        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2494            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2495                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2496            }
2497        });
2498        workspace.register_action(|workspace, _: &SignOut, window, cx| {
2499            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2500                copilot_ui::initiate_sign_out(copilot, window, cx);
2501            }
2502        });
2503    })
2504    .detach();
2505}