zeta.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
   7    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
   8    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
   9    ZED_VERSION_HEADER_NAME,
  10};
  11use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  12use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  13use collections::{HashMap, HashSet};
  14use command_palette_hooks::CommandPaletteFilter;
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{
  17    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  18    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  19    SyntaxIndex, SyntaxIndexState,
  20};
  21use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
  22use futures::channel::mpsc::UnboundedReceiver;
  23use futures::channel::{mpsc, oneshot};
  24use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, select_biased};
  25use gpui::BackgroundExecutor;
  26use gpui::{
  27    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  28    http_client::{self, AsyncBody, Method},
  29    prelude::*,
  30};
  31use language::{
  32    Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
  33};
  34use language::{BufferSnapshot, OffsetRangeExt};
  35use language_model::{LlmApiToken, RefreshLlmTokenListener};
  36use open_ai::FunctionDefinition;
  37use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  38use release_channel::AppVersion;
  39use semver::Version;
  40use serde::de::DeserializeOwned;
  41use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
  42use std::any::{Any as _, TypeId};
  43use std::collections::{VecDeque, hash_map};
  44use telemetry_events::EditPredictionRating;
  45use workspace::Workspace;
  46
  47use std::ops::Range;
  48use std::path::Path;
  49use std::rc::Rc;
  50use std::str::FromStr as _;
  51use std::sync::{Arc, LazyLock};
  52use std::time::{Duration, Instant};
  53use std::{env, mem};
  54use thiserror::Error;
  55use util::rel_path::RelPathBuf;
  56use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
  57use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  58
  59pub mod assemble_excerpts;
  60mod license_detection;
  61mod onboarding_modal;
  62mod prediction;
  63mod provider;
  64mod rate_prediction_modal;
  65pub mod retrieval_search;
  66pub mod sweep_ai;
  67pub mod udiff;
  68mod xml_edits;
  69pub mod zeta1;
  70
  71#[cfg(test)]
  72mod zeta_tests;
  73
  74use crate::assemble_excerpts::assemble_excerpts;
  75use crate::license_detection::LicenseDetectionWatcher;
  76use crate::onboarding_modal::ZedPredictModal;
  77pub use crate::prediction::EditPrediction;
  78pub use crate::prediction::EditPredictionId;
  79pub use crate::prediction::EditPredictionInputs;
  80use crate::prediction::EditPredictionResult;
  81use crate::rate_prediction_modal::{
  82    NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
  83    ThumbsUpActivePrediction,
  84};
  85pub use crate::sweep_ai::SweepAi;
  86use crate::zeta1::request_prediction_with_zeta1;
  87pub use provider::ZetaEditPredictionProvider;
  88
// GPUI actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
 100
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Consecutive edits whose end points are within this many rows of each other
/// are coalesced into a single change event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key-value store key persisting the user's data collection choice.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce applied before sending batched reject reports — presumably consumed
/// by `handle_rejected_predictions`; confirm against its implementation.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 106
 107pub struct SweepFeatureFlag;
 108
 109impl FeatureFlag for SweepFeatureFlag {
 110    const NAME: &str = "sweep-ai";
 111}
/// Default sizing for the excerpt captured around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim to spend half of the excerpt budget before the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context mode: agentic retrieval.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Defaults for `ContextMode::Agentic`.
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for `ContextMode::Syntax` (syntax-index based retrieval).
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default `ZetaOptions` used when constructing `Zeta`.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
 143
 144static USE_OLLAMA: LazyLock<bool> =
 145    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 146static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 147    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
 148        "qwen3-coder:30b".to_string()
 149    } else {
 150        "yqvev8r3".to_string()
 151    })
 152});
 153static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 154    match env::var("ZED_ZETA2_MODEL").as_deref() {
 155        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 156        Ok(model) => model,
 157        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 158        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 159    }
 160    .to_string()
 161});
 162static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 163    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 164        if *USE_OLLAMA {
 165            Some("http://localhost:11434/v1/chat/completions".into())
 166        } else {
 167            None
 168        }
 169    })
 170});
 171
/// Feature flag enabling the Zeta2 prediction pipeline.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Staff always get Zeta2 regardless of the flag's rollout state.
    fn enabled_for_staff() -> bool {
        true
    }
}

/// Process-wide handle to the singleton `Zeta` entity (see `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 186
/// Coordinates edit prediction across projects: records buffer-change events,
/// requests predictions from the selected backend, and tracks shown/rated
/// predictions.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Cached token for authenticating with the LLM backend.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server demands a newer Zed version
    // (see MINIMUM_REQUIRED_VERSION_HEADER_NAME) — confirm where it is written.
    update_required: bool,
    /// When present, debug events are forwarded here (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    pub sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    /// Rejected predictions are queued here and reported in batches by a
    /// background task spawned in `Zeta::new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}

/// Which backend serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}

/// Tuning knobs for prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}

/// How context for a prediction request is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Context is gathered by a model-driven search step.
    Agentic(AgenticContextOptions),
    /// Context is gathered from the `SyntaxIndex`.
    Syntax(EditPredictionContextOptions),
}

/// Options for `ContextMode::Agentic`.
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
 234
 235impl ContextMode {
 236    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 237        match self {
 238            ContextMode::Agentic(options) => &options.excerpt,
 239            ContextMode::Syntax(options) => &options.excerpt,
 240        }
 241    }
 242}
 243
/// Events forwarded to debug subscribers (see `Zeta::debug_info`).
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

/// Debug payload emitted when context retrieval begins.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// The prompt handed to the search model.
    pub search_prompt: String,
}

/// Debug payload shared by retrieval progress/completion events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Debug payload emitted when an edit prediction is requested.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    /// How long the preceding context-retrieval step took.
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// Locally assembled prompt, or an error message if assembly failed.
    pub local_prompt: Result<String, String>,
    /// Receives the model response (or error message) plus request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}

/// Debug payload emitted once search queries have been generated.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

/// Per-project state tracked by `Zeta`.
struct ZetaProject {
    /// Present only when `ContextMode::Syntax` is active.
    syntax_index: Option<Entity<SyntaxIndex>>,
    /// Finalized buffer-change events, oldest first, bounded by `EVENT_COUNT_MAX`.
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    /// Most recent change, kept open so nearby edits can coalesce into it.
    last_event: Option<LastEvent>,
    /// Recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered to the user, if any.
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (capacity two).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions cancelled before resolving.
    cancelled_predictions: HashSet<usize>,
    /// Retrieved context ranges per buffer, once retrieval has run.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    /// One license watcher per worktree; used to tag events as open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
 303
 304impl ZetaProject {
 305    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
 306        self.events
 307            .iter()
 308            .cloned()
 309            .chain(
 310                self.last_event
 311                    .as_ref()
 312                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 313            )
 314            .collect()
 315    }
 316
 317    fn cancel_pending_prediction(
 318        &mut self,
 319        pending_prediction: PendingPrediction,
 320        cx: &mut Context<Zeta>,
 321    ) {
 322        self.cancelled_predictions.insert(pending_prediction.id);
 323
 324        cx.spawn(async move |this, cx| {
 325            let Some(prediction_id) = pending_prediction.task.await else {
 326                return;
 327            };
 328
 329            this.update(cx, |this, _cx| {
 330                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 331            })
 332            .ok();
 333        })
 334        .detach()
 335    }
 336}
 337
/// The prediction currently offered to the user within a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been rendered to the user at least once.
    pub was_shown: bool,
}
 344
 345impl CurrentEditPrediction {
 346    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 347        let Some(new_edits) = self
 348            .prediction
 349            .interpolate(&self.prediction.buffer.read(cx))
 350        else {
 351            return false;
 352        };
 353
 354        if self.prediction.buffer != old_prediction.prediction.buffer {
 355            return true;
 356        }
 357
 358        let Some(old_edits) = old_prediction
 359            .prediction
 360            .interpolate(&old_prediction.prediction.buffer.read(cx))
 361        else {
 362            return true;
 363        };
 364
 365        let requested_by_buffer_id = self.requested_by.buffer_id();
 366
 367        // This reduces the occurrence of UI thrash from replacing edits
 368        //
 369        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 370        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 371            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 372            && old_edits.len() == 1
 373            && new_edits.len() == 1
 374        {
 375            let (old_range, old_text) = &old_edits[0];
 376            let (new_range, new_text) = &new_edits[0];
 377            new_range == old_range && new_text.starts_with(old_text.as_ref())
 378        } else {
 379            true
 380        }
 381    }
 382}
 383
 384#[derive(Debug, Clone)]
 385enum PredictionRequestedBy {
 386    DiagnosticsUpdate,
 387    Buffer(EntityId),
 388}
 389
 390impl PredictionRequestedBy {
 391    pub fn buffer_id(&self) -> Option<EntityId> {
 392        match self {
 393            PredictionRequestedBy::DiagnosticsUpdate => None,
 394            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 395        }
 396    }
 397}
 398
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Id used to mark this request as cancelled (see `cancelled_predictions`).
    id: usize,
    /// Resolves to the produced prediction's id, or `None` if none was produced.
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer directly.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer; surface a jump affordance.
    Jump { prediction: &'a EditPrediction },
}
 411
 412#[cfg(test)]
 413impl std::ops::Deref for BufferEditPrediction<'_> {
 414    type Target = EditPrediction;
 415
 416    fn deref(&self) -> &Self::Target {
 417        match self {
 418            BufferEditPrediction::Local { prediction } => prediction,
 419            BufferEditPrediction::Jump { prediction } => prediction,
 420        }
 421    }
 422}
 423
/// Per-buffer state kept while a buffer is registered with `Zeta`.
struct RegisteredBuffer {
    /// Snapshot as of the last processed change; diffed against new edits.
    snapshot: BufferSnapshot,
    /// Edit subscription plus release observer keeping this entry current.
    _subscriptions: [gpui::Subscription; 2],
}

/// The most recent, still-coalescing buffer change.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    /// End of the latest edit; used to decide whether a subsequent edit is
    /// close enough (within `CHANGE_GROUPING_LINE_SPAN` rows) to coalesce.
    end_edit_anchor: Option<Anchor>,
}
 434
 435impl LastEvent {
 436    pub fn finalize(
 437        &self,
 438        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 439        cx: &App,
 440    ) -> Option<Arc<predict_edits_v3::Event>> {
 441        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
 442        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
 443
 444        let file = self.new_snapshot.file();
 445        let old_file = self.old_snapshot.file();
 446
 447        let in_open_source_repo = [file, old_file].iter().all(|file| {
 448            file.is_some_and(|file| {
 449                license_detection_watchers
 450                    .get(&file.worktree_id(cx))
 451                    .is_some_and(|watcher| watcher.is_project_open_source())
 452            })
 453        });
 454
 455        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
 456
 457        if path == old_path && diff.is_empty() {
 458            None
 459        } else {
 460            Some(Arc::new(predict_edits_v3::Event::BufferChange {
 461                old_path,
 462                path,
 463                diff,
 464                in_open_source_repo,
 465                // TODO: Actually detect if this edit was predicted or not
 466                predicted: false,
 467            }))
 468        }
 469    }
 470}
 471
 472fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
 473    if let Some(file) = snapshot.file() {
 474        file.full_path(cx).into()
 475    } else {
 476        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 477    }
 478}
 479
 480impl Zeta {
 481    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 482        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 483    }
 484
 485    pub fn global(
 486        client: &Arc<Client>,
 487        user_store: &Entity<UserStore>,
 488        cx: &mut App,
 489    ) -> Entity<Self> {
 490        cx.try_global::<ZetaGlobal>()
 491            .map(|global| global.0.clone())
 492            .unwrap_or_else(|| {
 493                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 494                cx.set_global(ZetaGlobal(zeta.clone()));
 495                zeta
 496            })
 497    }
 498
    /// Creates the `Zeta` entity.
    ///
    /// Spawns a long-lived background task that drains and reports rejected
    /// predictions, and subscribes to LLM-token refresh events so the cached
    /// token stays valid.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections queued on `reject_tx` are consumed by the background
        // reporter task spawned below.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token,
            // Re-fetch the LLM token whenever the listener signals a refresh.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
 554
 555    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
 556        self.edit_prediction_model = model;
 557    }
 558
 559    pub fn has_sweep_api_token(&self) -> bool {
 560        self.sweep_ai
 561            .api_token
 562            .clone()
 563            .now_or_never()
 564            .flatten()
 565            .is_some()
 566    }
 567
 568    #[cfg(feature = "eval-support")]
 569    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
 570        self.eval_cache = Some(cache);
 571    }
 572
 573    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 574        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 575        self.debug_tx = Some(debug_watch_tx);
 576        debug_watch_rx
 577    }
 578
 579    pub fn options(&self) -> &ZetaOptions {
 580        &self.options
 581    }
 582
 583    pub fn set_options(&mut self, options: ZetaOptions) {
 584        self.options = options;
 585    }
 586
 587    pub fn clear_history(&mut self) {
 588        for zeta_project in self.projects.values_mut() {
 589            zeta_project.events.clear();
 590        }
 591    }
 592
 593    pub fn context_for_project(
 594        &self,
 595        project: &Entity<Project>,
 596    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 597        self.projects
 598            .get(&project.entity_id())
 599            .and_then(|project| {
 600                Some(
 601                    project
 602                        .context
 603                        .as_ref()?
 604                        .iter()
 605                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 606                )
 607            })
 608            .into_iter()
 609            .flatten()
 610    }
 611
 612    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 613        if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
 614            self.user_store.read(cx).edit_prediction_usage()
 615        } else {
 616            None
 617        }
 618    }
 619
 620    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 621        self.get_or_init_zeta_project(project, cx);
 622    }
 623
    /// Registers `buffer` for change tracking within `project`, initializing
    /// per-project state on first use.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 633
    /// Returns the state for `project`, creating it on first use.
    ///
    /// A `SyntaxIndex` is only built when the current context mode is
    /// `ContextMode::Syntax`; agentic mode does not need one.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                license_detection_watchers: HashMap::default(),
                // Keeps this state reacting to project events for its lifetime.
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
 666
 667    fn handle_project_event(
 668        &mut self,
 669        project: Entity<Project>,
 670        event: &project::Event,
 671        cx: &mut Context<Self>,
 672    ) {
 673        // TODO [zeta2] init with recent paths
 674        match event {
 675            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 676                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
 677                    return;
 678                };
 679                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 680                if let Some(path) = path {
 681                    if let Some(ix) = zeta_project
 682                        .recent_paths
 683                        .iter()
 684                        .position(|probe| probe == &path)
 685                    {
 686                        zeta_project.recent_paths.remove(ix);
 687                    }
 688                    zeta_project.recent_paths.push_front(path);
 689                }
 690            }
 691            project::Event::DiagnosticsUpdated { .. } => {
 692                if cx.has_flag::<Zeta2FeatureFlag>() {
 693                    self.refresh_prediction_from_diagnostics(project, cx);
 694                }
 695            }
 696            _ => (),
 697        }
 698    }
 699
    /// Ensures `buffer` is tracked in `zeta_project`, returning its entry.
    ///
    /// On first registration this also installs a license-detection watcher for
    /// the buffer's worktree, an edit subscription that feeds
    /// `report_changes_for_buffer`, and release observers that clean up the
    /// corresponding entries when the worktree/buffer are dropped.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily create one license watcher per worktree.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher when its worktree is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Feed every edit into the change-event pipeline.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Remove the entry when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 759
    /// Records the latest edits to `buffer` as change events.
    ///
    /// Edits to the same buffer that land within `CHANGE_GROUPING_LINE_SPAN`
    /// rows of the previous edit are coalesced into the open `last_event`;
    /// otherwise the open event is finalized and a new one started. The total
    /// number of tracked events is bounded by `EVENT_COUNT_MAX`.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        // Nothing to report when the snapshot hasn't advanced.
        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Anchor at the end of the most recent edit, used for proximity checks.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // The new edit continues the open event only if it picks up exactly
            // where that event's snapshot left off...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and it happened within the grouping distance of the last edit.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Make room so events + the open last_event stay within EVENT_COUNT_MAX.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        // Finalize the previous open event (skipped if it was a no-op).
        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
 822
 823    fn current_prediction_for_buffer(
 824        &self,
 825        buffer: &Entity<Buffer>,
 826        project: &Entity<Project>,
 827        cx: &App,
 828    ) -> Option<BufferEditPrediction<'_>> {
 829        let project_state = self.projects.get(&project.entity_id())?;
 830
 831        let CurrentEditPrediction {
 832            requested_by,
 833            prediction,
 834            ..
 835        } = project_state.current_prediction.as_ref()?;
 836
 837        if prediction.targets_buffer(buffer.read(cx)) {
 838            Some(BufferEditPrediction::Local { prediction })
 839        } else {
 840            let show_jump = match requested_by {
 841                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
 842                    requested_by_buffer_id == &buffer.entity_id()
 843                }
 844                PredictionRequestedBy::DiagnosticsUpdate => true,
 845            };
 846
 847            if show_jump {
 848                Some(BufferEditPrediction::Jump { prediction })
 849            } else {
 850                None
 851            }
 852        }
 853    }
 854
    /// Notifies the server that the current prediction for `project` was
    /// accepted, and clears the project's prediction state.
    ///
    /// No-op for the Sweep model. The accept endpoint can be overridden via
    /// the `ZED_ACCEPT_PREDICTION_URL` environment variable.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Take the prediction so it can't be accepted or rejected twice.
        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Accepting supersedes any in-flight refreshes; cancel them all.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // Environment override is used for local development / testing.
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // Routes token/version errors through the shared response handler.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 907
    /// Background task that batches rejected predictions and reports them to
    /// the `/predict_edits/reject` endpoint.
    ///
    /// Rejections arrive on `rx`. A flush happens when half the per-request
    /// limit has accumulated, when the debounce timer fires, or when no more
    /// items are immediately available. On request failure the batch is kept
    /// and retried on the next flush; only the newest
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` items are sent at a time.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    // If another rejection is already queued, keep draining the
                    // channel before issuing a request.
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            // Send at most one request's worth, preferring the newest entries.
            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
            )
            .await;

            // Drop the flushed slice only on success; on error the items stay
            // in `batched` for the next attempt.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
 965
 966    fn reject_current_prediction(
 967        &mut self,
 968        reason: EditPredictionRejectReason,
 969        project: &Entity<Project>,
 970    ) {
 971        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 972            project_state.pending_predictions.clear();
 973            if let Some(prediction) = project_state.current_prediction.take() {
 974                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
 975            }
 976        };
 977    }
 978
 979    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
 980        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 981            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
 982                if !current_prediction.was_shown {
 983                    current_prediction.was_shown = true;
 984                    self.shown_predictions
 985                        .push_front(current_prediction.prediction.clone());
 986                    if self.shown_predictions.len() > 50 {
 987                        let completion = self.shown_predictions.pop_back().unwrap();
 988                        self.rated_predictions.remove(&completion.id);
 989                    }
 990                }
 991            }
 992        }
 993    }
 994
 995    fn reject_prediction(
 996        &mut self,
 997        prediction_id: EditPredictionId,
 998        reason: EditPredictionRejectReason,
 999        was_shown: bool,
1000    ) {
1001        self.reject_predictions_tx
1002            .unbounded_send(EditPredictionRejection {
1003                request_id: prediction_id.to_string(),
1004                reason,
1005                was_shown,
1006            })
1007            .log_err();
1008    }
1009
1010    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1011        self.projects
1012            .get(&project.entity_id())
1013            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1014    }
1015
1016    pub fn refresh_prediction_from_buffer(
1017        &mut self,
1018        project: Entity<Project>,
1019        buffer: Entity<Buffer>,
1020        position: language::Anchor,
1021        cx: &mut Context<Self>,
1022    ) {
1023        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1024            let Some(request_task) = this
1025                .update(cx, |this, cx| {
1026                    this.request_prediction(
1027                        &project,
1028                        &buffer,
1029                        position,
1030                        PredictEditsRequestTrigger::Other,
1031                        cx,
1032                    )
1033                })
1034                .log_err()
1035            else {
1036                return Task::ready(anyhow::Ok(None));
1037            };
1038
1039            cx.spawn(async move |_cx| {
1040                request_task.await.map(|prediction_result| {
1041                    prediction_result.map(|prediction_result| {
1042                        (
1043                            prediction_result,
1044                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1045                        )
1046                    })
1047                })
1048            })
1049        })
1050    }
1051
    /// Queues a prediction refresh in response to a diagnostics update,
    /// targeting the nearest diagnostic location for the project's active
    /// entry.
    ///
    /// Does nothing when a prediction already exists (buffer-triggered
    /// predictions take precedence). If a prediction appears while this
    /// request is in flight, the new result is rejected as
    /// `CurrentPreferred` instead of replacing it.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Resolve the project's active entry to an open buffer.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Default range/point mean no prior search area to exclude.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // Re-check: a buffer-triggered prediction may have arrived
                // meanwhile, in which case this result is demoted to a
                // rejection.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_zeta_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1132
    // Minimum delay between prediction refreshes triggered by the same
    // entity; zero under test so tests don't wait out the throttle.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1137
    /// Enqueues a throttled prediction refresh for `project`.
    ///
    /// `throttle_entity` identifies the refresh trigger (a buffer or the
    /// project itself); back-to-back refreshes from the same entity wait out
    /// [`Self::THROTTLE_TIMEOUT`]. At most two refreshes are kept pending:
    /// queueing a third replaces and cancels the most recently queued one.
    ///
    /// When `do_refresh` resolves, the result either becomes the current
    /// prediction, replaces it, or is rejected (e.g. the existing prediction
    /// is preferred); pending refreshes enqueued before the finished one are
    /// cancelled.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        let zeta_project = self.get_or_init_zeta_project(&project, cx);
        let pending_prediction_id = zeta_project.next_pending_prediction_id;
        zeta_project.next_pending_prediction_id += 1;
        let last_request = zeta_project.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Throttle successive refreshes from the same entity.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_zeta_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let zeta_project = this.get_or_init_zeta_project(&project, cx);

                // The refresh may have been cancelled while `do_refresh` ran.
                let is_cancelled = zeta_project
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                            };

                            // Keep whichever prediction wins; report the loser
                            // as rejected so the server learns about it.
                            if let Some(current_prediction) =
                                zeta_project.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false);
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-borrow: the helpers above needed `this` mutably.
                let zeta_project = this.get_or_init_zeta_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    zeta_project.current_prediction = Some(new_prediction);
                }

                // Drop this entry from the pending list and cancel everything
                // queued before it (their results would now be stale).
                let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            zeta_project.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_zeta_project(&project, cx)
                    .pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Keep at most two pending refreshes: the running one plus one queued
        // replacement; a third refresh displaces the queued one.
        if zeta_project.pending_predictions.len() <= 1 {
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if zeta_project.pending_predictions.len() == 2 {
            let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            zeta_project.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1278
1279    pub fn request_prediction(
1280        &mut self,
1281        project: &Entity<Project>,
1282        active_buffer: &Entity<Buffer>,
1283        position: language::Anchor,
1284        trigger: PredictEditsRequestTrigger,
1285        cx: &mut Context<Self>,
1286    ) -> Task<Result<Option<EditPredictionResult>>> {
1287        self.request_prediction_internal(
1288            project.clone(),
1289            active_buffer.clone(),
1290            position,
1291            trigger,
1292            cx.has_flag::<Zeta2FeatureFlag>(),
1293            cx,
1294        )
1295    }
1296
    /// Dispatches a prediction request to the configured model
    /// (Zeta1, Zeta2, or Sweep).
    ///
    /// If the model returns no prediction, there are recorded edit events,
    /// and `allow_jump` is set, retries once at the nearest diagnostic
    /// outside the already-searched range — possibly in another buffer —
    /// with `allow_jump` cleared so the retry cannot jump again.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Rows above/below the cursor already covered by this request's
        // diagnostic search; jumps only target diagnostics outside it.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    // Retry at the jump target with jumping disabled to avoid
                    // unbounded recursion.
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1388
    /// Finds the next diagnostic to offer as a jump target.
    ///
    /// First looks in the active buffer for the diagnostic closest (by row)
    /// to the cursor that lies outside `active_buffer_diagnostic_search_range`.
    /// Failing that, opens the other buffer with diagnostics whose path
    /// shares the most leading components with the active buffer's path and
    /// picks a diagnostic from it. Returns `None` when there is nowhere to
    /// jump.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Fall back to another buffer in the project that has diagnostics.
        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // Picks the entry with the minimum severity value — in LSP
                // ordering lower values are more severe, so this presumably
                // selects the most severe diagnostic. NOTE(review): confirm.
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1467
1468    fn request_prediction_with_zeta2(
1469        &mut self,
1470        project: &Entity<Project>,
1471        active_buffer: &Entity<Buffer>,
1472        active_snapshot: BufferSnapshot,
1473        position: language::Anchor,
1474        events: Vec<Arc<Event>>,
1475        trigger: PredictEditsRequestTrigger,
1476        cx: &mut Context<Self>,
1477    ) -> Task<Result<Option<EditPredictionResult>>> {
1478        let project_state = self.projects.get(&project.entity_id());
1479
1480        let index_state = project_state.and_then(|state| {
1481            state
1482                .syntax_index
1483                .as_ref()
1484                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
1485        });
1486        let options = self.options.clone();
1487        let buffer_snapshotted_at = Instant::now();
1488        let Some(excerpt_path) = active_snapshot
1489            .file()
1490            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
1491        else {
1492            return Task::ready(Err(anyhow!("No file path for excerpt")));
1493        };
1494        let client = self.client.clone();
1495        let llm_token = self.llm_token.clone();
1496        let app_version = AppVersion::global(cx);
1497        let worktree_snapshots = project
1498            .read(cx)
1499            .worktrees(cx)
1500            .map(|worktree| worktree.read(cx).snapshot())
1501            .collect::<Vec<_>>();
1502        let debug_tx = self.debug_tx.clone();
1503
1504        let diagnostics = active_snapshot.diagnostic_sets().clone();
1505
1506        let file = active_buffer.read(cx).file();
1507        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
1508            let mut path = f.worktree.read(cx).absolutize(&f.path);
1509            if path.pop() { Some(path) } else { None }
1510        });
1511
1512        // TODO data collection
1513        let can_collect_data = file
1514            .as_ref()
1515            .map_or(false, |file| self.can_collect_file(project, file, cx));
1516
1517        let empty_context_files = HashMap::default();
1518        let context_files = project_state
1519            .and_then(|project_state| project_state.context.as_ref())
1520            .unwrap_or(&empty_context_files);
1521
1522        #[cfg(feature = "eval-support")]
1523        let parsed_fut = futures::future::join_all(
1524            context_files
1525                .keys()
1526                .map(|buffer| buffer.read(cx).parsing_idle()),
1527        );
1528
1529        let mut included_files = context_files
1530            .iter()
1531            .filter_map(|(buffer_entity, ranges)| {
1532                let buffer = buffer_entity.read(cx);
1533                Some((
1534                    buffer_entity.clone(),
1535                    buffer.snapshot(),
1536                    buffer.file()?.full_path(cx).into(),
1537                    ranges.clone(),
1538                ))
1539            })
1540            .collect::<Vec<_>>();
1541
1542        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
1543            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
1544        });
1545
1546        #[cfg(feature = "eval-support")]
1547        let eval_cache = self.eval_cache.clone();
1548
1549        let request_task = cx.background_spawn({
1550            let active_buffer = active_buffer.clone();
1551            async move {
1552                #[cfg(feature = "eval-support")]
1553                parsed_fut.await;
1554
1555                let index_state = if let Some(index_state) = index_state {
1556                    Some(index_state.lock_owned().await)
1557                } else {
1558                    None
1559                };
1560
1561                let cursor_offset = position.to_offset(&active_snapshot);
1562                let cursor_point = cursor_offset.to_point(&active_snapshot);
1563
1564                let before_retrieval = Instant::now();
1565
1566                let (diagnostic_groups, diagnostic_groups_truncated) =
1567                    Self::gather_nearby_diagnostics(
1568                        cursor_offset,
1569                        &diagnostics,
1570                        &active_snapshot,
1571                        options.max_diagnostic_bytes,
1572                    );
1573
1574                let cloud_request = match options.context {
1575                    ContextMode::Agentic(context_options) => {
1576                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
1577                            cursor_point,
1578                            &active_snapshot,
1579                            &context_options.excerpt,
1580                            index_state.as_deref(),
1581                        ) else {
1582                            return Ok((None, None));
1583                        };
1584
1585                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
1586                            ..active_snapshot.anchor_before(excerpt.range.end);
1587
1588                        if let Some(buffer_ix) =
1589                            included_files.iter().position(|(_, snapshot, _, _)| {
1590                                snapshot.remote_id() == active_snapshot.remote_id()
1591                            })
1592                        {
1593                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
1594                            ranges.push(excerpt_anchor_range);
1595                            retrieval_search::merge_anchor_ranges(ranges, buffer);
1596                            let last_ix = included_files.len() - 1;
1597                            included_files.swap(buffer_ix, last_ix);
1598                        } else {
1599                            included_files.push((
1600                                active_buffer.clone(),
1601                                active_snapshot.clone(),
1602                                excerpt_path.clone(),
1603                                vec![excerpt_anchor_range],
1604                            ));
1605                        }
1606
1607                        let included_files = included_files
1608                            .iter()
1609                            .map(|(_, snapshot, path, ranges)| {
1610                                let ranges = ranges
1611                                    .iter()
1612                                    .map(|range| {
1613                                        let point_range = range.to_point(&snapshot);
1614                                        Line(point_range.start.row)..Line(point_range.end.row)
1615                                    })
1616                                    .collect::<Vec<_>>();
1617                                let excerpts = assemble_excerpts(&snapshot, ranges);
1618                                predict_edits_v3::IncludedFile {
1619                                    path: path.clone(),
1620                                    max_row: Line(snapshot.max_point().row),
1621                                    excerpts,
1622                                }
1623                            })
1624                            .collect::<Vec<_>>();
1625
1626                        predict_edits_v3::PredictEditsRequest {
1627                            excerpt_path,
1628                            excerpt: String::new(),
1629                            excerpt_line_range: Line(0)..Line(0),
1630                            excerpt_range: 0..0,
1631                            cursor_point: predict_edits_v3::Point {
1632                                line: predict_edits_v3::Line(cursor_point.row),
1633                                column: cursor_point.column,
1634                            },
1635                            included_files,
1636                            referenced_declarations: vec![],
1637                            events,
1638                            can_collect_data,
1639                            diagnostic_groups,
1640                            diagnostic_groups_truncated,
1641                            debug_info: debug_tx.is_some(),
1642                            prompt_max_bytes: Some(options.max_prompt_bytes),
1643                            prompt_format: options.prompt_format,
1644                            // TODO [zeta2]
1645                            signatures: vec![],
1646                            excerpt_parent: None,
1647                            git_info: None,
1648                            trigger,
1649                        }
1650                    }
1651                    ContextMode::Syntax(context_options) => {
1652                        let Some(context) = EditPredictionContext::gather_context(
1653                            cursor_point,
1654                            &active_snapshot,
1655                            parent_abs_path.as_deref(),
1656                            &context_options,
1657                            index_state.as_deref(),
1658                        ) else {
1659                            return Ok((None, None));
1660                        };
1661
1662                        make_syntax_context_cloud_request(
1663                            excerpt_path,
1664                            context,
1665                            events,
1666                            can_collect_data,
1667                            diagnostic_groups,
1668                            diagnostic_groups_truncated,
1669                            None,
1670                            debug_tx.is_some(),
1671                            &worktree_snapshots,
1672                            index_state.as_deref(),
1673                            Some(options.max_prompt_bytes),
1674                            options.prompt_format,
1675                            trigger,
1676                        )
1677                    }
1678                };
1679
1680                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
1681
1682                let inputs = EditPredictionInputs {
1683                    included_files: cloud_request.included_files,
1684                    events: cloud_request.events,
1685                    cursor_point: cloud_request.cursor_point,
1686                    cursor_path: cloud_request.excerpt_path,
1687                };
1688
1689                let retrieval_time = Instant::now() - before_retrieval;
1690
1691                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
1692                    let (response_tx, response_rx) = oneshot::channel();
1693
1694                    debug_tx
1695                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
1696                            ZetaEditPredictionDebugInfo {
1697                                inputs: inputs.clone(),
1698                                retrieval_time,
1699                                buffer: active_buffer.downgrade(),
1700                                local_prompt: match prompt_result.as_ref() {
1701                                    Ok((prompt, _)) => Ok(prompt.clone()),
1702                                    Err(err) => Err(err.to_string()),
1703                                },
1704                                position,
1705                                response_rx,
1706                            },
1707                        ))
1708                        .ok();
1709                    Some(response_tx)
1710                } else {
1711                    None
1712                };
1713
1714                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
1715                    if let Some(debug_response_tx) = debug_response_tx {
1716                        debug_response_tx
1717                            .send((Err("Request skipped".to_string()), Duration::ZERO))
1718                            .ok();
1719                    }
1720                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
1721                }
1722
1723                let (prompt, _) = prompt_result?;
1724                let generation_params =
1725                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
1726                let request = open_ai::Request {
1727                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1728                    messages: vec![open_ai::RequestMessage::User {
1729                        content: open_ai::MessageContent::Plain(prompt),
1730                    }],
1731                    stream: false,
1732                    max_completion_tokens: None,
1733                    stop: generation_params.stop.unwrap_or_default(),
1734                    temperature: generation_params.temperature.unwrap_or(0.7),
1735                    tool_choice: None,
1736                    parallel_tool_calls: None,
1737                    tools: vec![],
1738                    prompt_cache_key: None,
1739                    reasoning_effort: None,
1740                };
1741
1742                log::trace!("Sending edit prediction request");
1743
1744                let before_request = Instant::now();
1745                let response = Self::send_raw_llm_request(
1746                    request,
1747                    client,
1748                    llm_token,
1749                    app_version,
1750                    #[cfg(feature = "eval-support")]
1751                    eval_cache,
1752                    #[cfg(feature = "eval-support")]
1753                    EvalCacheEntryKind::Prediction,
1754                )
1755                .await;
1756                let received_response_at = Instant::now();
1757                let request_time = received_response_at - before_request;
1758
1759                log::trace!("Got edit prediction response");
1760
1761                if let Some(debug_response_tx) = debug_response_tx {
1762                    debug_response_tx
1763                        .send((
1764                            response
1765                                .as_ref()
1766                                .map_err(|err| err.to_string())
1767                                .map(|response| response.0.clone()),
1768                            request_time,
1769                        ))
1770                        .ok();
1771                }
1772
1773                let (res, usage) = response?;
1774                let request_id = EditPredictionId(res.id.clone().into());
1775                let Some(mut output_text) = text_from_response(res) else {
1776                    return Ok((Some((request_id, None)), usage));
1777                };
1778
1779                if output_text.contains(CURSOR_MARKER) {
1780                    log::trace!("Stripping out {CURSOR_MARKER} from response");
1781                    output_text = output_text.replace(CURSOR_MARKER, "");
1782                }
1783
1784                let get_buffer_from_context = |path: &Path| {
1785                    included_files
1786                        .iter()
1787                        .find_map(|(_, buffer, probe_path, ranges)| {
1788                            if probe_path.as_ref() == path {
1789                                Some((buffer, ranges.as_slice()))
1790                            } else {
1791                                None
1792                            }
1793                        })
1794                };
1795
1796                let (edited_buffer_snapshot, edits) = match options.prompt_format {
1797                    PromptFormat::NumLinesUniDiff => {
1798                        // TODO: Implement parsing of multi-file diffs
1799                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1800                    }
1801                    PromptFormat::Minimal
1802                    | PromptFormat::MinimalQwen
1803                    | PromptFormat::SeedCoder1120 => {
1804                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
1805                            let edits = vec![];
1806                            (&active_snapshot, edits)
1807                        } else {
1808                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1809                        }
1810                    }
1811                    PromptFormat::OldTextNewText => {
1812                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
1813                            .await?
1814                    }
1815                    _ => {
1816                        bail!("unsupported prompt format {}", options.prompt_format)
1817                    }
1818                };
1819
1820                let edited_buffer = included_files
1821                    .iter()
1822                    .find_map(|(buffer, snapshot, _, _)| {
1823                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1824                            Some(buffer.clone())
1825                        } else {
1826                            None
1827                        }
1828                    })
1829                    .context("Failed to find buffer in included_buffers")?;
1830
1831                anyhow::Ok((
1832                    Some((
1833                        request_id,
1834                        Some((
1835                            inputs,
1836                            edited_buffer,
1837                            edited_buffer_snapshot.clone(),
1838                            edits,
1839                            received_response_at,
1840                        )),
1841                    )),
1842                    usage,
1843                ))
1844            }
1845        });
1846
1847        cx.spawn({
1848            async move |this, cx| {
1849                let Some((id, prediction)) =
1850                    Self::handle_api_response(&this, request_task.await, cx)?
1851                else {
1852                    return Ok(None);
1853                };
1854
1855                let Some((
1856                    inputs,
1857                    edited_buffer,
1858                    edited_buffer_snapshot,
1859                    edits,
1860                    received_response_at,
1861                )) = prediction
1862                else {
1863                    return Ok(Some(EditPredictionResult {
1864                        id,
1865                        prediction: Err(EditPredictionRejectReason::Empty),
1866                    }));
1867                };
1868
1869                // TODO telemetry: duration, etc
1870                Ok(Some(
1871                    EditPredictionResult::new(
1872                        id,
1873                        &edited_buffer,
1874                        &edited_buffer_snapshot,
1875                        edits.into(),
1876                        buffer_snapshotted_at,
1877                        received_response_at,
1878                        inputs,
1879                        cx,
1880                    )
1881                    .await,
1882                ))
1883            }
1884        })
1885    }
1886
1887    async fn send_raw_llm_request(
1888        request: open_ai::Request,
1889        client: Arc<Client>,
1890        llm_token: LlmApiToken,
1891        app_version: Version,
1892        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1893        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
1894    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1895        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1896            http_client::Url::parse(&predict_edits_url)?
1897        } else {
1898            client
1899                .http_client()
1900                .build_zed_llm_url("/predict_edits/raw", &[])?
1901        };
1902
1903        #[cfg(feature = "eval-support")]
1904        let cache_key = if let Some(cache) = eval_cache {
1905            use collections::FxHasher;
1906            use std::hash::{Hash, Hasher};
1907
1908            let mut hasher = FxHasher::default();
1909            url.hash(&mut hasher);
1910            let request_str = serde_json::to_string_pretty(&request)?;
1911            request_str.hash(&mut hasher);
1912            let hash = hasher.finish();
1913
1914            let key = (eval_cache_kind, hash);
1915            if let Some(response_str) = cache.read(key) {
1916                return Ok((serde_json::from_str(&response_str)?, None));
1917            }
1918
1919            Some((cache, request_str, key))
1920        } else {
1921            None
1922        };
1923
1924        let (response, usage) = Self::send_api_request(
1925            |builder| {
1926                let req = builder
1927                    .uri(url.as_ref())
1928                    .body(serde_json::to_string(&request)?.into());
1929                Ok(req?)
1930            },
1931            client,
1932            llm_token,
1933            app_version,
1934        )
1935        .await?;
1936
1937        #[cfg(feature = "eval-support")]
1938        if let Some((cache, request, key)) = cache_key {
1939            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1940        }
1941
1942        Ok((response, usage))
1943    }
1944
1945    fn handle_api_response<T>(
1946        this: &WeakEntity<Self>,
1947        response: Result<(T, Option<EditPredictionUsage>)>,
1948        cx: &mut gpui::AsyncApp,
1949    ) -> Result<T> {
1950        match response {
1951            Ok((data, usage)) => {
1952                if let Some(usage) = usage {
1953                    this.update(cx, |this, cx| {
1954                        this.user_store.update(cx, |user_store, cx| {
1955                            user_store.update_edit_prediction_usage(usage, cx);
1956                        });
1957                    })
1958                    .ok();
1959                }
1960                Ok(data)
1961            }
1962            Err(err) => {
1963                if err.is::<ZedUpdateRequiredError>() {
1964                    cx.update(|cx| {
1965                        this.update(cx, |this, _cx| {
1966                            this.update_required = true;
1967                        })
1968                        .ok();
1969
1970                        let error_message: SharedString = err.to_string().into();
1971                        show_app_notification(
1972                            NotificationId::unique::<ZedUpdateRequiredError>(),
1973                            cx,
1974                            move |cx| {
1975                                cx.new(|cx| {
1976                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1977                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1978                                })
1979                            },
1980                        );
1981                    })
1982                    .ok();
1983                }
1984                Err(err)
1985            }
1986        }
1987    }
1988
1989    async fn send_api_request<Res>(
1990        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1991        client: Arc<Client>,
1992        llm_token: LlmApiToken,
1993        app_version: Version,
1994    ) -> Result<(Res, Option<EditPredictionUsage>)>
1995    where
1996        Res: DeserializeOwned,
1997    {
1998        let http_client = client.http_client();
1999        let mut token = llm_token.acquire(&client).await?;
2000        let mut did_retry = false;
2001
2002        loop {
2003            let request_builder = http_client::Request::builder().method(Method::POST);
2004
2005            let request = build(
2006                request_builder
2007                    .header("Content-Type", "application/json")
2008                    .header("Authorization", format!("Bearer {}", token))
2009                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
2010            )?;
2011
2012            let mut response = http_client.send(request).await?;
2013
2014            if let Some(minimum_required_version) = response
2015                .headers()
2016                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2017                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2018            {
2019                anyhow::ensure!(
2020                    app_version >= minimum_required_version,
2021                    ZedUpdateRequiredError {
2022                        minimum_version: minimum_required_version
2023                    }
2024                );
2025            }
2026
2027            if response.status().is_success() {
2028                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2029
2030                let mut body = Vec::new();
2031                response.body_mut().read_to_end(&mut body).await?;
2032                return Ok((serde_json::from_slice(&body)?, usage));
2033            } else if !did_retry
2034                && response
2035                    .headers()
2036                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
2037                    .is_some()
2038            {
2039                did_retry = true;
2040                token = llm_token.refresh(&client).await?;
2041            } else {
2042                let mut body = String::new();
2043                response.body_mut().read_to_string(&mut body).await?;
2044                anyhow::bail!(
2045                    "Request failed with status: {:?}\nBody: {}",
2046                    response.status(),
2047                    body
2048                );
2049            }
2050        }
2051    }
2052
    /// How long without a context refresh before the user counts as idle;
    /// editing after this window triggers an immediate (non-debounced) refresh.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied to context refreshes while the user is actively editing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
2055
    // Refresh the related excerpts when the user just begins editing after
    // an idle period, and after they pause editing.
    //
    // Only applies when the Zeta2 model is active and the context mode is
    // agentic. If the user was idle (no refresh within
    // `CONTEXT_RETRIEVAL_IDLE_DURATION`), the refresh runs immediately;
    // otherwise it is debounced by `CONTEXT_RETRIEVAL_DEBOUNCE_DURATION`.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if !matches!(self.edit_prediction_model, ZetaEditPredictionModel::Zeta2) {
            return;
        }

        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        // Idle means no refresh has been requested within the idle window
        // (or ever, for a fresh project).
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Overwriting the task drops (and thereby cancels) any refresh that
        // is still waiting out its debounce timer.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    // Debounce: wait for a pause in editing before refreshing.
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Keep the task handle so it runs to completion.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
2107
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    //
    // Flow: build a planning prompt for the excerpt around the cursor, ask the
    // retrieval model (via a single tool definition) which searches to run,
    // execute those searches in the project, and store the resulting excerpts
    // as the project's context. Debug info is streamed on `debug_tx` at each
    // stage when a debug listener is attached.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Only registered projects participate.
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        // Context retrieval only applies in agentic context mode.
        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        // Select the excerpt surrounding the cursor; nothing to do if no
        // excerpt can be formed.
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        // Untitled buffers get a placeholder path.
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the planning prompt that asks the model which searches to run.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project.events(cx),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // JSON schema and description for the search tool, computed once per
        // process and shared across calls.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        // Non-streaming request exposing exactly one tool: the search planner.
        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            // The plan must come back as an assistant message with tool calls.
            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect queries from every call to the search tool, skipping
            // (but logging) calls to tools we didn't offer.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            // Execute the planned searches against the project.
            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Store the results and clear the in-flight task handle.
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
2319
2320    pub fn set_context(
2321        &mut self,
2322        project: Entity<Project>,
2323        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2324    ) {
2325        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2326            zeta_project.context = Some(context);
2327        }
2328    }
2329
2330    fn gather_nearby_diagnostics(
2331        cursor_offset: usize,
2332        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
2333        snapshot: &BufferSnapshot,
2334        max_diagnostics_bytes: usize,
2335    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
2336        // TODO: Could make this more efficient
2337        let mut diagnostic_groups = Vec::new();
2338        for (language_server_id, diagnostics) in diagnostic_sets {
2339            let mut groups = Vec::new();
2340            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
2341            diagnostic_groups.extend(
2342                groups
2343                    .into_iter()
2344                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
2345            );
2346        }
2347
2348        // sort by proximity to cursor
2349        diagnostic_groups.sort_by_key(|group| {
2350            let range = &group.entries[group.primary_ix].range;
2351            if range.start >= cursor_offset {
2352                range.start - cursor_offset
2353            } else if cursor_offset >= range.end {
2354                cursor_offset - range.end
2355            } else {
2356                (cursor_offset - range.start).min(range.end - cursor_offset)
2357            }
2358        });
2359
2360        let mut results = Vec::new();
2361        let mut diagnostic_groups_truncated = false;
2362        let mut diagnostics_byte_count = 0;
2363        for group in diagnostic_groups {
2364            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
2365            diagnostics_byte_count += raw_value.get().len();
2366            if diagnostics_byte_count > max_diagnostics_bytes {
2367                diagnostic_groups_truncated = true;
2368                break;
2369            }
2370            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
2371        }
2372
2373        (results, diagnostic_groups_truncated)
2374    }
2375
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds the cloud `PredictEditsRequest` that would be sent for the
    /// given cursor position, without performing the request (used by the
    /// zeta CLI).
    ///
    /// The returned task panics if the configured context mode is
    /// `Agentic`, which the CLI does not support yet.
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        // Clone the syntax index state handle (if any) so the background
        // task can lock it without touching the entity system.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, when it has one.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            // Hold the index lock for the whole of context gathering.
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                    PredictEditsRequestTrigger::Other,
                )
            })
        })
    }
2454
2455    pub fn wait_for_initial_indexing(
2456        &mut self,
2457        project: &Entity<Project>,
2458        cx: &mut Context<Self>,
2459    ) -> Task<Result<()>> {
2460        let zeta_project = self.get_or_init_zeta_project(project, cx);
2461        if let Some(syntax_index) = &zeta_project.syntax_index {
2462            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2463        } else {
2464            Task::ready(Ok(()))
2465        }
2466    }
2467
2468    fn is_file_open_source(
2469        &self,
2470        project: &Entity<Project>,
2471        file: &Arc<dyn File>,
2472        cx: &App,
2473    ) -> bool {
2474        if !file.is_local() || file.is_private() {
2475            return false;
2476        }
2477        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2478            return false;
2479        };
2480        zeta_project
2481            .license_detection_watchers
2482            .get(&file.worktree_id(cx))
2483            .as_ref()
2484            .is_some_and(|watcher| watcher.is_project_open_source())
2485    }
2486
2487    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2488        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2489    }
2490
2491    fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2492        if !self.data_collection_choice.is_enabled() {
2493            return false;
2494        }
2495        events.iter().all(|event| {
2496            matches!(
2497                event.as_ref(),
2498                Event::BufferChange {
2499                    in_open_source_repo: true,
2500                    ..
2501                }
2502            )
2503        })
2504    }
2505
2506    fn load_data_collection_choice() -> DataCollectionChoice {
2507        let choice = KEY_VALUE_STORE
2508            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2509            .log_err()
2510            .flatten();
2511
2512        match choice.as_deref() {
2513            Some("true") => DataCollectionChoice::Enabled,
2514            Some("false") => DataCollectionChoice::Disabled,
2515            Some(_) => {
2516                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2517                DataCollectionChoice::NotAnswered
2518            }
2519            None => DataCollectionChoice::NotAnswered,
2520        }
2521    }
2522
    /// Iterator over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2526
    /// Number of predictions that have been shown to the user.
    // NOTE(review): named "completions" for historical reasons; it counts
    // `shown_predictions`.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2530
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2534
    /// Records a user rating for `prediction`, reports it to telemetry
    /// (including the prediction's inputs and its output as a unified
    /// diff), and flushes telemetry events immediately.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Flush right away so the rating isn't lost if the app exits soon.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2553}
2554
2555pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2556    let choice = res.choices.pop()?;
2557    let output_text = match choice.message {
2558        open_ai::RequestMessage::Assistant {
2559            content: Some(open_ai::MessageContent::Plain(content)),
2560            ..
2561        } => content,
2562        open_ai::RequestMessage::Assistant {
2563            content: Some(open_ai::MessageContent::Multipart(mut content)),
2564            ..
2565        } => {
2566            if content.is_empty() {
2567                log::error!("No output from Baseten completion response");
2568                return None;
2569            }
2570
2571            match content.remove(0) {
2572                open_ai::MessagePart::Text { text } => text,
2573                open_ai::MessagePart::Image { .. } => {
2574                    log::error!("Expected text, got an image");
2575                    return None;
2576                }
2577            }
2578        }
2579        _ => {
2580            log::error!("Invalid response message: {:?}", choice.message);
2581            return None;
2582        }
2583    };
2584    Some(output_text)
2585}
2586
/// Error returned when the backend requires a newer Zed version for edit
/// predictions (presumably populated from the minimum-required-version
/// response header — confirm at the call site constructing it).
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    // Lowest Zed version the backend will accept.
    minimum_version: Version,
}
2594
2595fn make_syntax_context_cloud_request(
2596    excerpt_path: Arc<Path>,
2597    context: EditPredictionContext,
2598    events: Vec<Arc<predict_edits_v3::Event>>,
2599    can_collect_data: bool,
2600    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2601    diagnostic_groups_truncated: bool,
2602    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2603    debug_info: bool,
2604    worktrees: &Vec<worktree::Snapshot>,
2605    index_state: Option<&SyntaxIndexState>,
2606    prompt_max_bytes: Option<usize>,
2607    prompt_format: PromptFormat,
2608    trigger: PredictEditsRequestTrigger,
2609) -> predict_edits_v3::PredictEditsRequest {
2610    let mut signatures = Vec::new();
2611    let mut declaration_to_signature_index = HashMap::default();
2612    let mut referenced_declarations = Vec::new();
2613
2614    for snippet in context.declarations {
2615        let project_entry_id = snippet.declaration.project_entry_id();
2616        let Some(path) = worktrees.iter().find_map(|worktree| {
2617            worktree.entry_for_id(project_entry_id).map(|entry| {
2618                let mut full_path = RelPathBuf::new();
2619                full_path.push(worktree.root_name());
2620                full_path.push(&entry.path);
2621                full_path
2622            })
2623        }) else {
2624            continue;
2625        };
2626
2627        let parent_index = index_state.and_then(|index_state| {
2628            snippet.declaration.parent().and_then(|parent| {
2629                add_signature(
2630                    parent,
2631                    &mut declaration_to_signature_index,
2632                    &mut signatures,
2633                    index_state,
2634                )
2635            })
2636        });
2637
2638        let (text, text_is_truncated) = snippet.declaration.item_text();
2639        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2640            path: path.as_std_path().into(),
2641            text: text.into(),
2642            range: snippet.declaration.item_line_range(),
2643            text_is_truncated,
2644            signature_range: snippet.declaration.signature_range_in_item_text(),
2645            parent_index,
2646            signature_score: snippet.score(DeclarationStyle::Signature),
2647            declaration_score: snippet.score(DeclarationStyle::Declaration),
2648            score_components: snippet.components,
2649        });
2650    }
2651
2652    let excerpt_parent = index_state.and_then(|index_state| {
2653        context
2654            .excerpt
2655            .parent_declarations
2656            .last()
2657            .and_then(|(parent, _)| {
2658                add_signature(
2659                    *parent,
2660                    &mut declaration_to_signature_index,
2661                    &mut signatures,
2662                    index_state,
2663                )
2664            })
2665    });
2666
2667    predict_edits_v3::PredictEditsRequest {
2668        excerpt_path,
2669        excerpt: context.excerpt_text.body,
2670        excerpt_line_range: context.excerpt.line_range,
2671        excerpt_range: context.excerpt.range,
2672        cursor_point: predict_edits_v3::Point {
2673            line: predict_edits_v3::Line(context.cursor_point.row),
2674            column: context.cursor_point.column,
2675        },
2676        referenced_declarations,
2677        included_files: vec![],
2678        signatures,
2679        excerpt_parent,
2680        events,
2681        can_collect_data,
2682        diagnostic_groups,
2683        diagnostic_groups_truncated,
2684        git_info,
2685        debug_info,
2686        prompt_max_bytes,
2687        prompt_format,
2688        trigger,
2689    }
2690}
2691
2692fn add_signature(
2693    declaration_id: DeclarationId,
2694    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2695    signatures: &mut Vec<Signature>,
2696    index: &SyntaxIndexState,
2697) -> Option<usize> {
2698    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2699        return Some(*signature_index);
2700    }
2701    let Some(parent_declaration) = index.declaration(declaration_id) else {
2702        log::error!("bug: missing parent declaration");
2703        return None;
2704    };
2705    let parent_index = parent_declaration.parent().and_then(|parent| {
2706        add_signature(parent, declaration_to_signature_index, signatures, index)
2707    });
2708    let (text, text_is_truncated) = parent_declaration.signature_text();
2709    let signature_index = signatures.len();
2710    signatures.push(Signature {
2711        text: text.into(),
2712        text_is_truncated,
2713        parent_index,
2714        range: parent_declaration.signature_line_range(),
2715    });
2716    declaration_to_signature_index.insert(declaration_id, signature_index);
2717    Some(signature_index)
2718}
2719
#[cfg(feature = "eval-support")]
/// Cache key used by evals: an entry kind paired with a 64-bit hash
/// (presumably of the cached request's input — see `EvalCache::write`).
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2722
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
/// Pipeline stage a cached eval entry belongs to; rendered as a lowercase
/// label ("context" / "search" / "prediction") via its `Display` impl.
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2730
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Formats the entry kind as its lowercase cache-label string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(label)
    }
}
2741
#[cfg(feature = "eval-support")]
/// Cache used by evals to avoid repeating expensive context/search/
/// prediction requests.
pub trait EvalCache: Send + Sync {
    /// Returns the cached value for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` for `key`; `input` is the raw request text
    /// (presumably retained for inspection — confirm with implementors).
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2747
/// The user's decision on sharing edit-prediction data.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// True only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// True when the user made an explicit choice either way.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice; an unanswered state toggles to enabled.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
2788
/// Marker type tracking dismissal of the edit-prediction upsell.
struct ZedPredictUpsell;
2790
2791impl Dismissable for ZedPredictUpsell {
2792    const KEY: &'static str = "dismissed-edit-predict-upsell";
2793
2794    fn dismissed() -> bool {
2795        // To make this backwards compatible with older versions of Zed, we
2796        // check if the user has seen the previous Edit Prediction Onboarding
2797        // before, by checking the data collection choice which was written to
2798        // the database once the user clicked on "Accept and Enable"
2799        if KEY_VALUE_STORE
2800            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2801            .log_err()
2802            .is_some_and(|s| s.is_some())
2803        {
2804            return true;
2805        }
2806
2807        KEY_VALUE_STORE
2808            .read_kvp(Self::KEY)
2809            .log_err()
2810            .is_some_and(|s| s.is_some())
2811    }
2812}
2813
/// Whether the edit-prediction upsell modal should be shown (the user has
/// neither dismissed it nor completed the legacy onboarding).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2817
/// Registers Zeta's workspace actions and sets up command-palette
/// feature gating.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // The rating modal is only reachable behind the feature flag.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding clears the configured edit-prediction
        // provider in the settings file.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2853
2854fn feature_gate_predict_edits_actions(cx: &mut App) {
2855    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
2856    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
2857    let zeta_all_action_types = [
2858        TypeId::of::<RateCompletions>(),
2859        TypeId::of::<ResetOnboarding>(),
2860        zed_actions::OpenZedPredictOnboarding.type_id(),
2861        TypeId::of::<ClearHistory>(),
2862        TypeId::of::<ThumbsUpActivePrediction>(),
2863        TypeId::of::<ThumbsDownActivePrediction>(),
2864        TypeId::of::<NextEdit>(),
2865        TypeId::of::<PreviousEdit>(),
2866    ];
2867
2868    CommandPaletteFilter::update_global(cx, |filter, _cx| {
2869        filter.hide_action_types(&rate_completion_action_types);
2870        filter.hide_action_types(&reset_onboarding_action_types);
2871        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
2872    });
2873
2874    cx.observe_global::<SettingsStore>(move |cx| {
2875        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
2876        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();
2877
2878        CommandPaletteFilter::update_global(cx, |filter, _cx| {
2879            if is_ai_disabled {
2880                filter.hide_action_types(&zeta_all_action_types);
2881            } else if has_feature_flag {
2882                filter.show_action_types(&rate_completion_action_types);
2883            } else {
2884                filter.hide_action_types(&rate_completion_action_types);
2885            }
2886        });
2887    })
2888    .detach();
2889
2890    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
2891        if !DisableAiSettings::get_global(cx).disable_ai {
2892            if is_enabled {
2893                CommandPaletteFilter::update_global(cx, |filter, _cx| {
2894                    filter.show_action_types(&rate_completion_action_types);
2895                });
2896            } else {
2897                CommandPaletteFilter::update_global(cx, |filter, _cx| {
2898                    filter.hide_action_types(&rate_completion_action_types);
2899                });
2900            }
2901        }
2902    })
2903    .detach();
2904}
2905
2906#[cfg(test)]
2907mod tests {
2908    use std::{path::Path, sync::Arc, time::Duration};
2909
2910    use client::UserStore;
2911    use clock::FakeSystemClock;
2912    use cloud_llm_client::{
2913        EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
2914    };
2915    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2916    use futures::{
2917        AsyncReadExt, StreamExt,
2918        channel::{mpsc, oneshot},
2919    };
2920    use gpui::{
2921        Entity, TestAppContext,
2922        http_client::{FakeHttpClient, Response},
2923        prelude::*,
2924    };
2925    use indoc::indoc;
2926    use language::OffsetRangeExt as _;
2927    use open_ai::Usage;
2928    use pretty_assertions::{assert_eq, assert_matches};
2929    use project::{FakeFs, Project};
2930    use serde_json::json;
2931    use settings::SettingsStore;
2932    use util::path;
2933    use uuid::Uuid;
2934
2935    use crate::{BufferEditPrediction, EditPredictionId, REJECT_REQUEST_DEBOUNCE, Zeta};
2936
    /// End-to-end check of `current_prediction_for_buffer`: a prediction
    /// for the current file surfaces as `Local`, while a prediction
    /// targeting a different file surfaces as `Jump` until that buffer is
    /// opened.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        // A prediction for the buffer's own file should be Local.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        // Respond with a search tool call targeting 2.txt so that file
        // enters the retrieved context.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // The prediction targets 2.txt, so from buffer1 it appears as a
        // Jump to that file.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // From the target buffer itself, the same prediction is Local.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
3083
    /// A single prediction round-trip: send a request, respond with a model
    /// diff, and verify the resulting edit text and anchor position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        // The diff replaces "How" with "How are you?": the minimal edit is
        // an insertion of " are you?" at (1, 3).
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3147
    /// Edits made to a registered buffer are reported to the model as
    /// unified-diff events inside the prompt.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Edit the registered buffer so a change event is recorded.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (request, respond_tx) = requests.predict.next().await.unwrap();

        // The prompt should include the edit as a unified diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3221
    #[gpui::test]
    async fn test_empty_prediction(cx: &mut TestAppContext) {
        // A model response whose diff is a no-op (the "edit" reproduces the
        // existing text) must not become the current prediction, and must be
        // reported back as rejected with reason `Empty`.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // The hunk replaces "How" with the identical "How", so applying it
        // changes nothing.
        const NO_OP_DIFF: &str = indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How
             Bye
        "};

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let response = model_response(NO_OP_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        // No current prediction results from the no-op response.
        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::Empty,
                was_shown: false
            }]
        );
    }
3285
    #[gpui::test]
    async fn test_interpolated_empty(cx: &mut TestAppContext) {
        // If the user types the predicted text while the request is in
        // flight, interpolating the prediction onto the new buffer contents
        // leaves no edits; the prediction is then dropped and reported as
        // rejected with reason `InterpolatedEmpty`.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // Before the response arrives, make the buffer already contain the
        // text SIMPLE_DIFF would produce.
        buffer.update(cx, |buffer, cx| {
            buffer.set_text("Hello!\nHow are you?\nBye", cx);
        });

        let response = model_response(SIMPLE_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        // The interpolated prediction is empty, so nothing becomes current.
        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::InterpolatedEmpty,
                was_shown: false
            }]
        );
    }
3344
    // Canned model response diff used by several tests below: rewrites the
    // "How" line of root/foo.md (initial contents "Hello!\nHow\nBye\n") to
    // "How are you?".
    const SIMPLE_DIFF: &str = indoc! { r"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ ... @@
         Hello!
        -How
        +How are you?
         Bye
    "};
3354
    #[gpui::test]
    async fn test_replace_current(cx: &mut TestAppContext) {
        // When a second prediction of equal quality completes, it replaces
        // the current one, and the replaced prediction is reported as
        // rejected with reason `Replaced`.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_tx.send(first_response).unwrap();

        cx.run_until_parked();

        // The first prediction becomes current.
        zeta.read_with(cx, |zeta, cx| {
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // a second request is triggered
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_tx.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // second replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as replaced
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }]
        );
    }
3434
    #[gpui::test]
    async fn test_current_preferred(cx: &mut TestAppContext) {
        // When a newly completed prediction is worse than the current one
        // (here: a strict prefix of the current prediction's edit), the
        // current prediction is kept and the new one is rejected with
        // reason `CurrentPreferred`.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_tx.send(first_response).unwrap();

        cx.run_until_parked();

        // The first prediction becomes current.
        zeta.read_with(cx, |zeta, cx| {
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // a second request is triggered
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        // worse than current prediction
        let second_response = model_response(indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "});
        let second_id = second_response.id.clone();
        respond_tx.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // first is preferred over second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // second is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: second_id,
                reason: EditPredictionRejectReason::CurrentPreferred,
                was_shown: false
            }]
        );
    }
3523
    #[gpui::test]
    async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
        // When two requests are in flight and the later one completes first,
        // the earlier one is cancelled: its late response must not replace
        // the current prediction, and it is reported as rejected with
        // reason `Canceled`.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        // start two refresh tasks
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_first) = requests.predict.next().await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_second) = requests.predict.next().await.unwrap();

        // wait for throttle
        cx.run_until_parked();

        // second responds first
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_second.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // Now let the first (cancelled) request respond late.
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still second, since first was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            }]
        );
    }
3612
    #[gpui::test]
    async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
        // With two requests already pending, starting a third cancels the
        // second (the most recent pending one). The first may still
        // complete and become current; the third then replaces it. The
        // rejections reported are: second `Canceled`, first `Replaced`.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        // start two refresh tasks
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_first) = requests.predict.next().await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_second) = requests.predict.next().await.unwrap();

        // wait for throttle, so requests are sent
        cx.run_until_parked();

        zeta.update(cx, |zeta, cx| {
            // start a third request
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

            // 2 are pending, so 2nd is cancelled (index 1 in the pending list)
            assert_eq!(
                zeta.get_or_init_zeta_project(&project, cx)
                    .cancelled_predictions
                    .iter()
                    .copied()
                    .collect::<Vec<_>>(),
                [1]
            );
        });

        // wait for throttle
        cx.run_until_parked();

        let (_, respond_third) = requests.predict.next().await.unwrap();

        // The first request was not cancelled; its response becomes current.
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // The cancelled second request responds late; its result is dropped.
        let cancelled_response = model_response(SIMPLE_DIFF);
        let cancelled_id = cancelled_response.id.clone();
        respond_second.send(cancelled_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still first, since second was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        let third_response = model_response(SIMPLE_DIFF);
        let third_response_id = third_response.id.clone();
        respond_third.send(third_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // third completes and replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                third_response_id
            );
        });

        // second is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[
                EditPredictionRejection {
                    request_id: cancelled_id,
                    reason: EditPredictionRejectReason::Canceled,
                    was_shown: false
                },
                EditPredictionRejection {
                    request_id: first_id,
                    reason: EditPredictionRejectReason::Replaced,
                    was_shown: false
                }
            ]
        );
    }
3744
3745    #[gpui::test]
3746    async fn test_rejections_flushing(cx: &mut TestAppContext) {
3747        let (zeta, mut requests) = init_test(cx);
3748
3749        zeta.update(cx, |zeta, _cx| {
3750            zeta.reject_prediction(
3751                EditPredictionId("test-1".into()),
3752                EditPredictionRejectReason::Discarded,
3753                false,
3754            );
3755            zeta.reject_prediction(
3756                EditPredictionId("test-2".into()),
3757                EditPredictionRejectReason::Canceled,
3758                true,
3759            );
3760        });
3761
3762        cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
3763        cx.run_until_parked();
3764
3765        let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
3766        respond_tx.send(()).unwrap();
3767
3768        // batched
3769        assert_eq!(reject_request.rejections.len(), 2);
3770        assert_eq!(
3771            reject_request.rejections[0],
3772            EditPredictionRejection {
3773                request_id: "test-1".to_string(),
3774                reason: EditPredictionRejectReason::Discarded,
3775                was_shown: false
3776            }
3777        );
3778        assert_eq!(
3779            reject_request.rejections[1],
3780            EditPredictionRejection {
3781                request_id: "test-2".to_string(),
3782                reason: EditPredictionRejectReason::Canceled,
3783                was_shown: true
3784            }
3785        );
3786
3787        // Reaching batch size limit sends without debounce
3788        zeta.update(cx, |zeta, _cx| {
3789            for i in 0..70 {
3790                zeta.reject_prediction(
3791                    EditPredictionId(format!("batch-{}", i).into()),
3792                    EditPredictionRejectReason::Discarded,
3793                    false,
3794                );
3795            }
3796        });
3797
3798        // First MAX/2 items are sent immediately
3799        cx.run_until_parked();
3800        let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
3801        respond_tx.send(()).unwrap();
3802
3803        assert_eq!(reject_request.rejections.len(), 50);
3804        assert_eq!(reject_request.rejections[0].request_id, "batch-0");
3805        assert_eq!(reject_request.rejections[49].request_id, "batch-49");
3806
3807        // Remaining items are debounced with the next batch
3808        cx.executor().advance_clock(Duration::from_secs(15));
3809        cx.run_until_parked();
3810
3811        let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
3812        respond_tx.send(()).unwrap();
3813
3814        assert_eq!(reject_request.rejections.len(), 20);
3815        assert_eq!(reject_request.rejections[0].request_id, "batch-50");
3816        assert_eq!(reject_request.rejections[19].request_id, "batch-69");
3817
3818        // Request failure
3819        zeta.update(cx, |zeta, _cx| {
3820            zeta.reject_prediction(
3821                EditPredictionId("retry-1".into()),
3822                EditPredictionRejectReason::Discarded,
3823                false,
3824            );
3825        });
3826
3827        cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
3828        cx.run_until_parked();
3829
3830        let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
3831        assert_eq!(reject_request.rejections.len(), 1);
3832        assert_eq!(reject_request.rejections[0].request_id, "retry-1");
3833        // Simulate failure
3834        drop(_respond_tx);
3835
3836        // Add another rejection
3837        zeta.update(cx, |zeta, _cx| {
3838            zeta.reject_prediction(
3839                EditPredictionId("retry-2".into()),
3840                EditPredictionRejectReason::Discarded,
3841                false,
3842            );
3843        });
3844
3845        cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
3846        cx.run_until_parked();
3847
3848        // Retry should include both the failed item and the new one
3849        let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
3850        respond_tx.send(()).unwrap();
3851
3852        assert_eq!(reject_request.rejections.len(), 2);
3853        assert_eq!(reject_request.rejections[0].request_id, "retry-1");
3854        assert_eq!(reject_request.rejections[1].request_id, "retry-2");
3855    }
3856
3857    // Skipped until we start including diagnostics in prompt
3858    // #[gpui::test]
3859    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3860    //     let (zeta, mut req_rx) = init_test(cx);
3861    //     let fs = FakeFs::new(cx.executor());
3862    //     fs.insert_tree(
3863    //         "/root",
3864    //         json!({
3865    //             "foo.md": "Hello!\nBye"
3866    //         }),
3867    //     )
3868    //     .await;
3869    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3870
3871    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3872    //     let diagnostic = lsp::Diagnostic {
3873    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3874    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
3875    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3876    //         ..Default::default()
3877    //     };
3878
3879    //     project.update(cx, |project, cx| {
3880    //         project.lsp_store().update(cx, |lsp_store, cx| {
3881    //             // Create some diagnostics
3882    //             lsp_store
3883    //                 .update_diagnostics(
3884    //                     LanguageServerId(0),
3885    //                     lsp::PublishDiagnosticsParams {
3886    //                         uri: path_to_buffer_uri.clone(),
3887    //                         diagnostics: vec![diagnostic],
3888    //                         version: None,
3889    //                     },
3890    //                     None,
3891    //                     language::DiagnosticSourceKind::Pushed,
3892    //                     &[],
3893    //                     cx,
3894    //                 )
3895    //                 .unwrap();
3896    //         });
3897    //     });
3898
3899    //     let buffer = project
3900    //         .update(cx, |project, cx| {
3901    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3902    //             project.open_buffer(path, cx)
3903    //         })
3904    //         .await
3905    //         .unwrap();
3906
3907    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3908    //     let position = snapshot.anchor_before(language::Point::new(0, 0));
3909
3910    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
3911    //         zeta.request_prediction(&project, &buffer, position, cx)
3912    //     });
3913
3914    //     let (request, _respond_tx) = req_rx.next().await.unwrap();
3915
3916    //     assert_eq!(request.diagnostic_groups.len(), 1);
3917    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3918    //         .unwrap();
3919    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3920    //     assert_eq!(
3921    //         value,
3922    //         json!({
3923    //             "entries": [{
3924    //                 "range": {
3925    //                     "start": 8,
3926    //                     "end": 10
3927    //                 },
3928    //                 "diagnostic": {
3929    //                     "source": null,
3930    //                     "code": null,
3931    //                     "code_description": null,
3932    //                     "severity": 1,
3933    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3934    //                     "markdown": null,
3935    //                     "group_id": 0,
3936    //                     "is_primary": true,
3937    //                     "is_disk_based": false,
3938    //                     "is_unnecessary": false,
3939    //                     "source_kind": "Pushed",
3940    //                     "data": null,
3941    //                     "underline": true
3942    //                 }
3943    //             }],
3944    //             "primary_ix": 0
3945    //         })
3946    //     );
3947    // }
3948
3949    fn model_response(text: &str) -> open_ai::Response {
3950        open_ai::Response {
3951            id: Uuid::new_v4().to_string(),
3952            object: "response".into(),
3953            created: 0,
3954            model: "model".into(),
3955            choices: vec![open_ai::Choice {
3956                index: 0,
3957                message: open_ai::RequestMessage::Assistant {
3958                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
3959                    tool_calls: vec![],
3960                },
3961                finish_reason: None,
3962            }],
3963            usage: Usage {
3964                prompt_tokens: 0,
3965                completion_tokens: 0,
3966                total_tokens: 0,
3967            },
3968        }
3969    }
3970
3971    fn prompt_from_request(request: &open_ai::Request) -> &str {
3972        assert_eq!(request.messages.len(), 1);
3973        let open_ai::RequestMessage::User {
3974            content: open_ai::MessageContent::Plain(content),
3975            ..
3976        } = &request.messages[0]
3977        else {
3978            panic!(
3979                "Request does not have single user message of type Plain. {:#?}",
3980                request
3981            );
3982        };
3983        content
3984    }
3985
    /// Receivers for requests intercepted by the fake HTTP server in
    /// `init_test`. Each item pairs the deserialized request body with a
    /// oneshot sender the test uses to supply the server's response.
    struct RequestChannels {
        // Prediction requests sent to "/predict_edits/raw".
        predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
        // Rejection reports sent to "/predict_edits/reject".
        reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
    }
3990
    /// Sets up a `Zeta` entity backed by a fake HTTP client. Requests to the
    /// prediction and rejection endpoints are forwarded (with a response
    /// channel) through the returned [`RequestChannels`], letting tests
    /// inspect each request and decide its response.
    fn init_test(cx: &mut TestAppContext) -> (Entity<Zeta>, RequestChannels) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
            let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let predict_req_tx = predict_req_tx.clone();
                    let reject_req_tx = reject_req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh: always succeeds with a fixed token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction endpoint: hand the deserialized body to
                            // the test and block until it sends a response.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            // Rejection endpoint: same pattern as predictions;
                            // dropping the sender simulates a request failure.
                            "/predict_edits/reject" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (
                zeta,
                RequestChannels {
                    predict: predict_req_rx,
                    reject: reject_req_rx,
                },
            )
        })
    }
4057}