zeta.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
   7    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
   8    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
   9    ZED_VERSION_HEADER_NAME,
  10};
  11use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  12use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  13use collections::{HashMap, HashSet};
  14use command_palette_hooks::CommandPaletteFilter;
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{
  17    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  18    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  19    SyntaxIndex, SyntaxIndexState,
  20};
  21use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
  22use futures::channel::mpsc::UnboundedReceiver;
  23use futures::channel::{mpsc, oneshot};
  24use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, select_biased};
  25use gpui::BackgroundExecutor;
  26use gpui::{
  27    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  28    http_client::{self, AsyncBody, Method},
  29    prelude::*,
  30};
  31use language::{
  32    Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
  33};
  34use language::{BufferSnapshot, OffsetRangeExt};
  35use language_model::{LlmApiToken, RefreshLlmTokenListener};
  36use open_ai::FunctionDefinition;
  37use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  38use release_channel::AppVersion;
  39use semver::Version;
  40use serde::de::DeserializeOwned;
  41use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
  42use std::any::{Any as _, TypeId};
  43use std::collections::{VecDeque, hash_map};
  44use telemetry_events::EditPredictionRating;
  45use workspace::Workspace;
  46
  47use std::ops::Range;
  48use std::path::Path;
  49use std::rc::Rc;
  50use std::str::FromStr as _;
  51use std::sync::{Arc, LazyLock};
  52use std::time::{Duration, Instant};
  53use std::{env, mem};
  54use thiserror::Error;
  55use util::rel_path::RelPathBuf;
  56use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
  57use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  58
  59pub mod assemble_excerpts;
  60mod license_detection;
  61mod onboarding_modal;
  62mod prediction;
  63mod provider;
  64mod rate_prediction_modal;
  65pub mod retrieval_search;
  66pub mod sweep_ai;
  67pub mod udiff;
  68mod xml_edits;
  69pub mod zeta1;
  70
  71#[cfg(test)]
  72mod zeta_tests;
  73
  74use crate::assemble_excerpts::assemble_excerpts;
  75use crate::license_detection::LicenseDetectionWatcher;
  76use crate::onboarding_modal::ZedPredictModal;
  77pub use crate::prediction::EditPrediction;
  78pub use crate::prediction::EditPredictionId;
  79pub use crate::prediction::EditPredictionInputs;
  80use crate::prediction::EditPredictionResult;
  81use crate::rate_prediction_modal::{
  82    NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
  83    ThumbsUpActivePrediction,
  84};
  85pub use crate::sweep_ai::SweepAi;
  86use crate::zeta1::request_prediction_with_zeta1;
  87pub use provider::ZetaEditPredictionProvider;
  88
// Workspace actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
 100
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Edits within this many rows of the previous edit are coalesced into a
/// single buffer-change event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key-value-store key under which the user's data-collection choice is persisted.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval for batching rejected predictions before reporting them
/// to the server (used by the background reject loop).
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 106
 107pub struct SweepFeatureFlag;
 108
 109impl FeatureFlag for SweepFeatureFlag {
 110    const NAME: &str = "sweep-ai";
 111}
/// Default sizing for the excerpt of the buffer sent around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim to place the cursor halfway through the excerpt.
    target_before_cursor_over_total_bytes: 0.5,
};
 117
/// Default context mode: agentic retrieval with the default excerpt sizing.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Default options for the agentic (search-driven) context mode.
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};
 124
/// Default options for the syntax-index-driven context mode.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };
 134
/// Default [`ZetaOptions`] used when constructing [`Zeta`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
 143
 144static USE_OLLAMA: LazyLock<bool> =
 145    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 146static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 147    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
 148        "qwen3-coder:30b".to_string()
 149    } else {
 150        "yqvev8r3".to_string()
 151    })
 152});
 153static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 154    match env::var("ZED_ZETA2_MODEL").as_deref() {
 155        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 156        Ok(model) => model,
 157        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 158        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 159    }
 160    .to_string()
 161});
 162static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 163    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 164        if *USE_OLLAMA {
 165            Some("http://localhost:11434/v1/chat/completions".into())
 166        } else {
 167            None
 168        }
 169    })
 170});
 171
/// Feature flag gating the Zeta2 edit-prediction backend.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Staff accounts always see Zeta2, regardless of the flag's rollout state.
    fn enabled_for_staff() -> bool {
        true
    }
}
 181
/// App-global handle to the singleton [`Zeta`] entity (see `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 186
/// Singleton edit-prediction state shared across all projects
/// (installed as a global via [`ZetaGlobal`]).
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Keeps `llm_token` refreshed when the refresh listener fires.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by project entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server reports this Zed version is
    // too old (see MINIMUM_REQUIRED_VERSION_HEADER_NAME) — confirm the writer.
    update_required: bool,
    /// Sender for debug events streamed to consumers of `debug_info()`.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    /// Which backend serves predictions (Zeta1, Zeta2, or Sweep).
    edit_prediction_model: ZetaEditPredictionModel,
    pub sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    /// Feeds rejected predictions to the background reporting loop started in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    // Predictions shown to the user / rated by the user, for the rating modal.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
 205
/// Backend used to produce edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}
 213
/// Tuning knobs for prediction requests and context gathering.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism for the syntax index (see `get_or_init_zeta_project`).
    pub file_indexing_parallelism: usize,
    /// How long consecutive buffer changes are grouped into one event.
    pub buffer_change_grouping_interval: Duration,
}
 223
/// How context for a prediction is gathered: via agentic search or via the syntax index.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    Agentic(AgenticContextOptions),
    Syntax(EditPredictionContextOptions),
}
 229
/// Options for the agentic (search-driven) context mode.
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
 234
 235impl ContextMode {
 236    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 237        match self {
 238            ContextMode::Agentic(options) => &options.excerpt,
 239            ContextMode::Syntax(options) => &options.excerpt,
 240        }
 241    }
 242}
 243
/// Events streamed to debug consumers (see `Zeta::debug_info`), covering the
/// lifecycle of context retrieval and prediction requests.
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}
 252
/// Debug payload emitted when context retrieval begins for a project.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// The prompt sent to the search model.
    pub search_prompt: String,
}
 259
/// Debug payload for context-retrieval progress events (queries executed / finished).
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}
 265
/// Debug payload emitted when an edit prediction is requested.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    /// Time spent gathering context before the request was issued.
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// The locally-assembled prompt, or an error message if assembly failed.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error string) and request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}
 275
/// Debug payload emitted when search queries have been generated for retrieval.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}
 282
 283pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 284
/// Per-project prediction state tracked by [`Zeta`].
struct ZetaProject {
    /// Present only in `ContextMode::Syntax` (see `get_or_init_zeta_project`).
    syntax_index: Option<Entity<SyntaxIndex>>,
    /// Finalized buffer-change events, bounded by `EVENT_COUNT_MAX`.
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    /// The most recent change, still open for coalescing with subsequent edits.
    last_event: Option<LastEvent>,
    /// Recently active project paths; front is most recent.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// At most two in-flight prediction requests.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled before resolving.
    cancelled_predictions: HashSet<usize>,
    /// Retrieved context ranges, per buffer, if retrieval has run.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    /// Per-worktree watchers that detect open-source licenses (used by `LastEvent::finalize`).
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
 303
 304impl ZetaProject {
 305    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
 306        self.events
 307            .iter()
 308            .cloned()
 309            .chain(
 310                self.last_event
 311                    .as_ref()
 312                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 313            )
 314            .collect()
 315    }
 316
 317    fn cancel_pending_prediction(
 318        &mut self,
 319        pending_prediction: PendingPrediction,
 320        cx: &mut Context<Zeta>,
 321    ) {
 322        self.cancelled_predictions.insert(pending_prediction.id);
 323
 324        cx.spawn(async move |this, cx| {
 325            let Some(prediction_id) = pending_prediction.task.await else {
 326                return;
 327            };
 328
 329            this.update(cx, |this, _cx| {
 330                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 331            })
 332            .ok();
 333        })
 334        .detach()
 335    }
 336}
 337
/// The prediction currently offered to the user, with provenance.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been displayed to the user.
    pub was_shown: bool,
}
 344
 345impl CurrentEditPrediction {
 346    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 347        let Some(new_edits) = self
 348            .prediction
 349            .interpolate(&self.prediction.buffer.read(cx))
 350        else {
 351            return false;
 352        };
 353
 354        if self.prediction.buffer != old_prediction.prediction.buffer {
 355            return true;
 356        }
 357
 358        let Some(old_edits) = old_prediction
 359            .prediction
 360            .interpolate(&old_prediction.prediction.buffer.read(cx))
 361        else {
 362            return true;
 363        };
 364
 365        let requested_by_buffer_id = self.requested_by.buffer_id();
 366
 367        // This reduces the occurrence of UI thrash from replacing edits
 368        //
 369        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 370        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 371            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 372            && old_edits.len() == 1
 373            && new_edits.len() == 1
 374        {
 375            let (old_range, old_text) = &old_edits[0];
 376            let (new_range, new_text) = &new_edits[0];
 377            new_range == old_range && new_text.starts_with(old_text.as_ref())
 378        } else {
 379            true
 380        }
 381    }
 382}
 383
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    /// A change in the buffer with this entity id.
    Buffer(EntityId),
}
 389
 390impl PredictionRequestedBy {
 391    pub fn buffer_id(&self) -> Option<EntityId> {
 392        match self {
 393            PredictionRequestedBy::DiagnosticsUpdate => None,
 394            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 395        }
 396    }
 397}
 398
/// An in-flight prediction request; the task resolves to the prediction's id
/// (or `None` if no prediction was produced).
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}
 404
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets a different buffer (surfaced here as a jump).
    Jump { prediction: &'a EditPrediction },
}
 411
 412#[cfg(test)]
 413impl std::ops::Deref for BufferEditPrediction<'_> {
 414    type Target = EditPrediction;
 415
 416    fn deref(&self) -> &Self::Target {
 417        match self {
 418            BufferEditPrediction::Local { prediction } => prediction,
 419            BufferEditPrediction::Jump { prediction } => prediction,
 420        }
 421    }
 422}
 423
/// Tracking state for a buffer registered with a [`ZetaProject`]:
/// the last snapshot seen, plus edit/release subscriptions.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}
 428
/// The most recent buffer change, kept open so nearby follow-up edits can be
/// coalesced into it before it is finalized into an event.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    /// End position of the last edit, used to decide coalescing proximity.
    end_edit_anchor: Option<Anchor>,
}
 434
 435impl LastEvent {
 436    pub fn finalize(
 437        &self,
 438        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 439        cx: &App,
 440    ) -> Option<Arc<predict_edits_v3::Event>> {
 441        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
 442        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
 443
 444        let file = self.new_snapshot.file();
 445        let old_file = self.old_snapshot.file();
 446
 447        let in_open_source_repo = [file, old_file].iter().all(|file| {
 448            file.is_some_and(|file| {
 449                license_detection_watchers
 450                    .get(&file.worktree_id(cx))
 451                    .is_some_and(|watcher| watcher.is_project_open_source())
 452            })
 453        });
 454
 455        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
 456
 457        if path == old_path && diff.is_empty() {
 458            None
 459        } else {
 460            Some(Arc::new(predict_edits_v3::Event::BufferChange {
 461                old_path,
 462                path,
 463                diff,
 464                in_open_source_repo,
 465                // TODO: Actually detect if this edit was predicted or not
 466                predicted: false,
 467            }))
 468        }
 469    }
 470}
 471
 472fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
 473    if let Some(file) = snapshot.file() {
 474        file.full_path(cx).into()
 475    } else {
 476        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 477    }
 478}
 479
 480impl Zeta {
 481    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 482        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 483    }
 484
 485    pub fn global(
 486        client: &Arc<Client>,
 487        user_store: &Entity<UserStore>,
 488        cx: &mut App,
 489    ) -> Entity<Self> {
 490        cx.try_global::<ZetaGlobal>()
 491            .map(|global| global.0.clone())
 492            .unwrap_or_else(|| {
 493                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 494                cx.set_global(ZetaGlobal(zeta.clone()));
 495                zeta
 496            })
 497    }
 498
    /// Creates the singleton prediction state: starts the background loop that
    /// reports rejected predictions and subscribes to LLM token refreshes.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejected predictions are queued on this channel and drained by a
        // background task (`handle_rejected_predictions`).
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token,
            // Re-fetch the LLM token whenever the listener signals it expired.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
 554
    /// Selects which backend serves predictions (Zeta1, Zeta2, or Sweep).
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }
 558
 559    pub fn has_sweep_api_token(&self) -> bool {
 560        self.sweep_ai
 561            .api_token
 562            .clone()
 563            .now_or_never()
 564            .flatten()
 565            .is_some()
 566    }
 567
    /// Installs a cache used by eval runs (only built with `eval-support`).
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 572
 573    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 574        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 575        self.debug_tx = Some(debug_watch_tx);
 576        debug_watch_rx
 577    }
 578
    /// The current prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 582
    /// Replaces the prediction options wholesale.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 586
 587    pub fn clear_history(&mut self) {
 588        for zeta_project in self.projects.values_mut() {
 589            zeta_project.events.clear();
 590        }
 591    }
 592
 593    pub fn context_for_project(
 594        &self,
 595        project: &Entity<Project>,
 596    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 597        self.projects
 598            .get(&project.entity_id())
 599            .and_then(|project| {
 600                Some(
 601                    project
 602                        .context
 603                        .as_ref()?
 604                        .iter()
 605                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 606                )
 607            })
 608            .into_iter()
 609            .flatten()
 610    }
 611
 612    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 613        if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
 614            self.user_store.read(cx).edit_prediction_usage()
 615        } else {
 616            None
 617        }
 618    }
 619
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }
 623
    /// Starts tracking `buffer` within `project`, initializing the project's
    /// state if needed.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 633
    /// Returns the state for `project`, creating it on first use.
    /// A syntax index is only built when the configured context mode is `Syntax`.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                license_detection_watchers: HashMap::default(),
                // Keeps project state in sync with project events (see
                // `handle_project_event`).
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
 666
 667    fn handle_project_event(
 668        &mut self,
 669        project: Entity<Project>,
 670        event: &project::Event,
 671        cx: &mut Context<Self>,
 672    ) {
 673        // TODO [zeta2] init with recent paths
 674        match event {
 675            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 676                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
 677                    return;
 678                };
 679                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 680                if let Some(path) = path {
 681                    if let Some(ix) = zeta_project
 682                        .recent_paths
 683                        .iter()
 684                        .position(|probe| probe == &path)
 685                    {
 686                        zeta_project.recent_paths.remove(ix);
 687                    }
 688                    zeta_project.recent_paths.push_front(path);
 689                }
 690            }
 691            project::Event::DiagnosticsUpdated { .. } => {
 692                if cx.has_flag::<Zeta2FeatureFlag>() {
 693                    self.refresh_prediction_from_diagnostics(project, cx);
 694                }
 695            }
 696            _ => (),
 697        }
 698    }
 699
    /// Registers `buffer` with the project's state (idempotent), wiring up:
    /// a license-detection watcher for the buffer's worktree, an edit
    /// subscription that records changes, and release observers that clean up
    /// the corresponding entries.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily create one license watcher per worktree; drop it when the
        // worktree is released.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record every edit into the project's event history.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Unregister the buffer when it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 759
    /// Records a buffer change into the project's event history.
    ///
    /// Consecutive edits to the same buffer that land within
    /// `CHANGE_GROUPING_LINE_SPAN` rows of the previous edit are coalesced
    /// into the open `last_event` rather than producing a new event.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // No version change means nothing to record.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Position of the last edit in this change, used for proximity grouping.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // Only coalesce when this change picks up exactly where the open
            // event left off (same buffer, consecutive versions).
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // Extend the open event in place; no new event is created.
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Bound the history: drop the oldest event to make room.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        // Close out the previously open event before starting a new one.
        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
 822
 823    fn current_prediction_for_buffer(
 824        &self,
 825        buffer: &Entity<Buffer>,
 826        project: &Entity<Project>,
 827        cx: &App,
 828    ) -> Option<BufferEditPrediction<'_>> {
 829        let project_state = self.projects.get(&project.entity_id())?;
 830
 831        let CurrentEditPrediction {
 832            requested_by,
 833            prediction,
 834            ..
 835        } = project_state.current_prediction.as_ref()?;
 836
 837        if prediction.targets_buffer(buffer.read(cx)) {
 838            Some(BufferEditPrediction::Local { prediction })
 839        } else {
 840            let show_jump = match requested_by {
 841                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
 842                    requested_by_buffer_id == &buffer.entity_id()
 843                }
 844                PredictionRequestedBy::DiagnosticsUpdate => true,
 845            };
 846
 847            if show_jump {
 848                Some(BufferEditPrediction::Jump { prediction })
 849            } else {
 850                None
 851            }
 852        }
 853    }
 854
    /// Reports acceptance of the current prediction to the server, clearing it
    /// locally and cancelling any prediction refreshes still in flight.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        // Acceptance is only reported for the Zeta models.
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Take the prediction so it can no longer be shown or re-accepted.
        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // The accepted prediction supersedes anything still being requested.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // ZED_ACCEPT_PREDICTION_URL overrides the endpoint (e.g. for local
            // development); otherwise build the standard LLM endpoint URL.
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            // Serialize and send the acceptance off the main thread.
            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 907
    /// Background worker that drains the rejection channel and reports
    /// rejected predictions to the server in debounced batches.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            // While the batch is small, wait for either another rejection to
            // arrive (keep accumulating) or the debounce timer to elapse
            // (flush what we have). At half the per-request cap, flush
            // immediately without waiting.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            // Flush at most the per-request cap, taking the most recent tail
            // of `batched`; entries beyond the cap (which can only accumulate
            // after failed sends) stay behind for a later flush.
            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
            )
            .await;

            // Only drop the flushed entries on success; on failure they are
            // kept and retried the next time a rejection arrives.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
 965
 966    fn reject_current_prediction(
 967        &mut self,
 968        reason: EditPredictionRejectReason,
 969        project: &Entity<Project>,
 970    ) {
 971        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 972            project_state.pending_predictions.clear();
 973            if let Some(prediction) = project_state.current_prediction.take() {
 974                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
 975            }
 976        };
 977    }
 978
 979    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
 980        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 981            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
 982                if !current_prediction.was_shown {
 983                    current_prediction.was_shown = true;
 984                    self.shown_predictions
 985                        .push_front(current_prediction.prediction.clone());
 986                    if self.shown_predictions.len() > 50 {
 987                        let completion = self.shown_predictions.pop_back().unwrap();
 988                        self.rated_predictions.remove(&completion.id);
 989                    }
 990                }
 991            }
 992        }
 993    }
 994
 995    fn reject_prediction(
 996        &mut self,
 997        prediction_id: EditPredictionId,
 998        reason: EditPredictionRejectReason,
 999        was_shown: bool,
1000    ) {
1001        match self.edit_prediction_model {
1002            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
1003            ZetaEditPredictionModel::Sweep => return,
1004        }
1005
1006        self.reject_predictions_tx
1007            .unbounded_send(EditPredictionRejection {
1008                request_id: prediction_id.to_string(),
1009                reason,
1010                was_shown,
1011            })
1012            .log_err();
1013    }
1014
1015    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1016        self.projects
1017            .get(&project.entity_id())
1018            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1019    }
1020
1021    pub fn refresh_prediction_from_buffer(
1022        &mut self,
1023        project: Entity<Project>,
1024        buffer: Entity<Buffer>,
1025        position: language::Anchor,
1026        cx: &mut Context<Self>,
1027    ) {
1028        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1029            let Some(request_task) = this
1030                .update(cx, |this, cx| {
1031                    this.request_prediction(
1032                        &project,
1033                        &buffer,
1034                        position,
1035                        PredictEditsRequestTrigger::Other,
1036                        cx,
1037                    )
1038                })
1039                .log_err()
1040            else {
1041                return Task::ready(anyhow::Ok(None));
1042            };
1043
1044            cx.spawn(async move |_cx| {
1045                request_task.await.map(|prediction_result| {
1046                    prediction_result.map(|prediction_result| {
1047                        (
1048                            prediction_result,
1049                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1050                        )
1051                    })
1052                })
1053            })
1054        })
1055    }
1056
    /// Queues a prediction refresh driven by a diagnostics update: locates a
    /// diagnostic to jump to starting from the project's active entry and
    /// requests a prediction there. Buffer-driven predictions take priority.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        // Throttle on the project entity since this refresh isn't tied to a
        // particular buffer.
        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Start from the project's active entry, if one can be opened.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Default (empty) search range and cursor position: no
                // diagnostics are excluded as "already nearby".
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        // If a prediction appeared while we were waiting,
                        // keep it and reject this one as `CurrentPreferred`.
                        if this
                            .get_or_init_zeta_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1137
    // Minimum delay between prediction refreshes triggered by the same
    // throttle entity (see `queue_prediction_refresh`). Zero under test so
    // tests don't have to wait out the debounce.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1142
    /// Enqueues a prediction refresh produced by `do_refresh`, throttling
    /// consecutive refreshes from the same `throttle_entity` and keeping at
    /// most two refreshes pending per project (the one running plus the
    /// newest queued one).
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        let zeta_project = self.get_or_init_zeta_project(&project, cx);
        let pending_prediction_id = zeta_project.next_pending_prediction_id;
        zeta_project.next_pending_prediction_id += 1;
        let last_request = zeta_project.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Wait out the remainder of the throttle window when the previous
            // refresh came from the same entity.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_zeta_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let zeta_project = this.get_or_init_zeta_project(&project, cx);

                // The refresh may also have been cancelled while it ran.
                let is_cancelled = zeta_project
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                            };

                            // Either the new prediction replaces the current
                            // one (reported as `Replaced`), or the new one is
                            // itself rejected as `CurrentPreferred`.
                            if let Some(current_prediction) =
                                zeta_project.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false);
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-acquire the project state; the `&mut self` calls above
                // ended the previous borrow.
                let zeta_project = this.get_or_init_zeta_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    zeta_project.current_prediction = Some(new_prediction);
                }

                let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        // Everything enqueued before this refresh is stale.
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            zeta_project.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_zeta_project(&project, cx)
                    .pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Keep at most two pending refreshes: with zero or one pending, just
        // enqueue; with two, the newest queued one is replaced and cancelled.
        if zeta_project.pending_predictions.len() <= 1 {
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if zeta_project.pending_predictions.len() == 2 {
            let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            zeta_project.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1283
1284    pub fn request_prediction(
1285        &mut self,
1286        project: &Entity<Project>,
1287        active_buffer: &Entity<Buffer>,
1288        position: language::Anchor,
1289        trigger: PredictEditsRequestTrigger,
1290        cx: &mut Context<Self>,
1291    ) -> Task<Result<Option<EditPredictionResult>>> {
1292        self.request_prediction_internal(
1293            project.clone(),
1294            active_buffer.clone(),
1295            position,
1296            trigger,
1297            cx.has_flag::<Zeta2FeatureFlag>(),
1298            cx,
1299        )
1300    }
1301
    /// Dispatches a prediction request to the configured model. When the
    /// request yields no prediction and `allow_jump` is set, retries once at
    /// the nearest diagnostic outside the already-searched range.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Rows within this distance of the cursor count as "nearby" for
        // diagnostics purposes.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        // Search window of +/- DIAGNOSTIC_LINES_RANGE rows around the cursor.
        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            // No prediction at the cursor: optionally retry once at the next
            // diagnostic location. The recursive call passes `allow_jump:
            // false`, so it cannot recurse further.
            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1393
    /// Picks a diagnostic location to jump to: first the diagnostic in the
    /// active buffer closest to the cursor but outside the already-searched
    /// range, and failing that, a diagnostic in the project buffer whose path
    /// is "nearest" to the active buffer's.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                // Each group is represented by its primary entry's start.
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Fall back to diagnostics in other buffers of the project.
        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // Pick the entry with the minimum severity value within the
                // chosen buffer.
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1472
1473    fn request_prediction_with_zeta2(
1474        &mut self,
1475        project: &Entity<Project>,
1476        active_buffer: &Entity<Buffer>,
1477        active_snapshot: BufferSnapshot,
1478        position: language::Anchor,
1479        events: Vec<Arc<Event>>,
1480        trigger: PredictEditsRequestTrigger,
1481        cx: &mut Context<Self>,
1482    ) -> Task<Result<Option<EditPredictionResult>>> {
1483        let project_state = self.projects.get(&project.entity_id());
1484
1485        let index_state = project_state.and_then(|state| {
1486            state
1487                .syntax_index
1488                .as_ref()
1489                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
1490        });
1491        let options = self.options.clone();
1492        let buffer_snapshotted_at = Instant::now();
1493        let Some(excerpt_path) = active_snapshot
1494            .file()
1495            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
1496        else {
1497            return Task::ready(Err(anyhow!("No file path for excerpt")));
1498        };
1499        let client = self.client.clone();
1500        let llm_token = self.llm_token.clone();
1501        let app_version = AppVersion::global(cx);
1502        let worktree_snapshots = project
1503            .read(cx)
1504            .worktrees(cx)
1505            .map(|worktree| worktree.read(cx).snapshot())
1506            .collect::<Vec<_>>();
1507        let debug_tx = self.debug_tx.clone();
1508
1509        let diagnostics = active_snapshot.diagnostic_sets().clone();
1510
1511        let file = active_buffer.read(cx).file();
1512        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
1513            let mut path = f.worktree.read(cx).absolutize(&f.path);
1514            if path.pop() { Some(path) } else { None }
1515        });
1516
1517        // TODO data collection
1518        let can_collect_data = file
1519            .as_ref()
1520            .map_or(false, |file| self.can_collect_file(project, file, cx));
1521
1522        let empty_context_files = HashMap::default();
1523        let context_files = project_state
1524            .and_then(|project_state| project_state.context.as_ref())
1525            .unwrap_or(&empty_context_files);
1526
1527        #[cfg(feature = "eval-support")]
1528        let parsed_fut = futures::future::join_all(
1529            context_files
1530                .keys()
1531                .map(|buffer| buffer.read(cx).parsing_idle()),
1532        );
1533
1534        let mut included_files = context_files
1535            .iter()
1536            .filter_map(|(buffer_entity, ranges)| {
1537                let buffer = buffer_entity.read(cx);
1538                Some((
1539                    buffer_entity.clone(),
1540                    buffer.snapshot(),
1541                    buffer.file()?.full_path(cx).into(),
1542                    ranges.clone(),
1543                ))
1544            })
1545            .collect::<Vec<_>>();
1546
1547        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
1548            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
1549        });
1550
1551        #[cfg(feature = "eval-support")]
1552        let eval_cache = self.eval_cache.clone();
1553
1554        let request_task = cx.background_spawn({
1555            let active_buffer = active_buffer.clone();
1556            async move {
1557                #[cfg(feature = "eval-support")]
1558                parsed_fut.await;
1559
1560                let index_state = if let Some(index_state) = index_state {
1561                    Some(index_state.lock_owned().await)
1562                } else {
1563                    None
1564                };
1565
1566                let cursor_offset = position.to_offset(&active_snapshot);
1567                let cursor_point = cursor_offset.to_point(&active_snapshot);
1568
1569                let before_retrieval = Instant::now();
1570
1571                let (diagnostic_groups, diagnostic_groups_truncated) =
1572                    Self::gather_nearby_diagnostics(
1573                        cursor_offset,
1574                        &diagnostics,
1575                        &active_snapshot,
1576                        options.max_diagnostic_bytes,
1577                    );
1578
1579                let cloud_request = match options.context {
1580                    ContextMode::Agentic(context_options) => {
1581                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
1582                            cursor_point,
1583                            &active_snapshot,
1584                            &context_options.excerpt,
1585                            index_state.as_deref(),
1586                        ) else {
1587                            return Ok((None, None));
1588                        };
1589
1590                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
1591                            ..active_snapshot.anchor_before(excerpt.range.end);
1592
1593                        if let Some(buffer_ix) =
1594                            included_files.iter().position(|(_, snapshot, _, _)| {
1595                                snapshot.remote_id() == active_snapshot.remote_id()
1596                            })
1597                        {
1598                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
1599                            ranges.push(excerpt_anchor_range);
1600                            retrieval_search::merge_anchor_ranges(ranges, buffer);
1601                            let last_ix = included_files.len() - 1;
1602                            included_files.swap(buffer_ix, last_ix);
1603                        } else {
1604                            included_files.push((
1605                                active_buffer.clone(),
1606                                active_snapshot.clone(),
1607                                excerpt_path.clone(),
1608                                vec![excerpt_anchor_range],
1609                            ));
1610                        }
1611
1612                        let included_files = included_files
1613                            .iter()
1614                            .map(|(_, snapshot, path, ranges)| {
1615                                let ranges = ranges
1616                                    .iter()
1617                                    .map(|range| {
1618                                        let point_range = range.to_point(&snapshot);
1619                                        Line(point_range.start.row)..Line(point_range.end.row)
1620                                    })
1621                                    .collect::<Vec<_>>();
1622                                let excerpts = assemble_excerpts(&snapshot, ranges);
1623                                predict_edits_v3::IncludedFile {
1624                                    path: path.clone(),
1625                                    max_row: Line(snapshot.max_point().row),
1626                                    excerpts,
1627                                }
1628                            })
1629                            .collect::<Vec<_>>();
1630
1631                        predict_edits_v3::PredictEditsRequest {
1632                            excerpt_path,
1633                            excerpt: String::new(),
1634                            excerpt_line_range: Line(0)..Line(0),
1635                            excerpt_range: 0..0,
1636                            cursor_point: predict_edits_v3::Point {
1637                                line: predict_edits_v3::Line(cursor_point.row),
1638                                column: cursor_point.column,
1639                            },
1640                            included_files,
1641                            referenced_declarations: vec![],
1642                            events,
1643                            can_collect_data,
1644                            diagnostic_groups,
1645                            diagnostic_groups_truncated,
1646                            debug_info: debug_tx.is_some(),
1647                            prompt_max_bytes: Some(options.max_prompt_bytes),
1648                            prompt_format: options.prompt_format,
1649                            // TODO [zeta2]
1650                            signatures: vec![],
1651                            excerpt_parent: None,
1652                            git_info: None,
1653                            trigger,
1654                        }
1655                    }
1656                    ContextMode::Syntax(context_options) => {
1657                        let Some(context) = EditPredictionContext::gather_context(
1658                            cursor_point,
1659                            &active_snapshot,
1660                            parent_abs_path.as_deref(),
1661                            &context_options,
1662                            index_state.as_deref(),
1663                        ) else {
1664                            return Ok((None, None));
1665                        };
1666
1667                        make_syntax_context_cloud_request(
1668                            excerpt_path,
1669                            context,
1670                            events,
1671                            can_collect_data,
1672                            diagnostic_groups,
1673                            diagnostic_groups_truncated,
1674                            None,
1675                            debug_tx.is_some(),
1676                            &worktree_snapshots,
1677                            index_state.as_deref(),
1678                            Some(options.max_prompt_bytes),
1679                            options.prompt_format,
1680                            trigger,
1681                        )
1682                    }
1683                };
1684
1685                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
1686
1687                let inputs = EditPredictionInputs {
1688                    included_files: cloud_request.included_files,
1689                    events: cloud_request.events,
1690                    cursor_point: cloud_request.cursor_point,
1691                    cursor_path: cloud_request.excerpt_path,
1692                };
1693
1694                let retrieval_time = Instant::now() - before_retrieval;
1695
1696                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
1697                    let (response_tx, response_rx) = oneshot::channel();
1698
1699                    debug_tx
1700                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
1701                            ZetaEditPredictionDebugInfo {
1702                                inputs: inputs.clone(),
1703                                retrieval_time,
1704                                buffer: active_buffer.downgrade(),
1705                                local_prompt: match prompt_result.as_ref() {
1706                                    Ok((prompt, _)) => Ok(prompt.clone()),
1707                                    Err(err) => Err(err.to_string()),
1708                                },
1709                                position,
1710                                response_rx,
1711                            },
1712                        ))
1713                        .ok();
1714                    Some(response_tx)
1715                } else {
1716                    None
1717                };
1718
1719                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
1720                    if let Some(debug_response_tx) = debug_response_tx {
1721                        debug_response_tx
1722                            .send((Err("Request skipped".to_string()), Duration::ZERO))
1723                            .ok();
1724                    }
1725                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
1726                }
1727
1728                let (prompt, _) = prompt_result?;
1729                let generation_params =
1730                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
1731                let request = open_ai::Request {
1732                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1733                    messages: vec![open_ai::RequestMessage::User {
1734                        content: open_ai::MessageContent::Plain(prompt),
1735                    }],
1736                    stream: false,
1737                    max_completion_tokens: None,
1738                    stop: generation_params.stop.unwrap_or_default(),
1739                    temperature: generation_params.temperature.unwrap_or(0.7),
1740                    tool_choice: None,
1741                    parallel_tool_calls: None,
1742                    tools: vec![],
1743                    prompt_cache_key: None,
1744                    reasoning_effort: None,
1745                };
1746
1747                log::trace!("Sending edit prediction request");
1748
1749                let before_request = Instant::now();
1750                let response = Self::send_raw_llm_request(
1751                    request,
1752                    client,
1753                    llm_token,
1754                    app_version,
1755                    #[cfg(feature = "eval-support")]
1756                    eval_cache,
1757                    #[cfg(feature = "eval-support")]
1758                    EvalCacheEntryKind::Prediction,
1759                )
1760                .await;
1761                let received_response_at = Instant::now();
1762                let request_time = received_response_at - before_request;
1763
1764                log::trace!("Got edit prediction response");
1765
1766                if let Some(debug_response_tx) = debug_response_tx {
1767                    debug_response_tx
1768                        .send((
1769                            response
1770                                .as_ref()
1771                                .map_err(|err| err.to_string())
1772                                .map(|response| response.0.clone()),
1773                            request_time,
1774                        ))
1775                        .ok();
1776                }
1777
1778                let (res, usage) = response?;
1779                let request_id = EditPredictionId(res.id.clone().into());
1780                let Some(mut output_text) = text_from_response(res) else {
1781                    return Ok((Some((request_id, None)), usage));
1782                };
1783
1784                if output_text.contains(CURSOR_MARKER) {
1785                    log::trace!("Stripping out {CURSOR_MARKER} from response");
1786                    output_text = output_text.replace(CURSOR_MARKER, "");
1787                }
1788
1789                let get_buffer_from_context = |path: &Path| {
1790                    included_files
1791                        .iter()
1792                        .find_map(|(_, buffer, probe_path, ranges)| {
1793                            if probe_path.as_ref() == path {
1794                                Some((buffer, ranges.as_slice()))
1795                            } else {
1796                                None
1797                            }
1798                        })
1799                };
1800
1801                let (edited_buffer_snapshot, edits) = match options.prompt_format {
1802                    PromptFormat::NumLinesUniDiff => {
1803                        // TODO: Implement parsing of multi-file diffs
1804                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1805                    }
1806                    PromptFormat::Minimal
1807                    | PromptFormat::MinimalQwen
1808                    | PromptFormat::SeedCoder1120 => {
1809                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
1810                            let edits = vec![];
1811                            (&active_snapshot, edits)
1812                        } else {
1813                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1814                        }
1815                    }
1816                    PromptFormat::OldTextNewText => {
1817                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
1818                            .await?
1819                    }
1820                    _ => {
1821                        bail!("unsupported prompt format {}", options.prompt_format)
1822                    }
1823                };
1824
1825                let edited_buffer = included_files
1826                    .iter()
1827                    .find_map(|(buffer, snapshot, _, _)| {
1828                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1829                            Some(buffer.clone())
1830                        } else {
1831                            None
1832                        }
1833                    })
1834                    .context("Failed to find buffer in included_buffers")?;
1835
1836                anyhow::Ok((
1837                    Some((
1838                        request_id,
1839                        Some((
1840                            inputs,
1841                            edited_buffer,
1842                            edited_buffer_snapshot.clone(),
1843                            edits,
1844                            received_response_at,
1845                        )),
1846                    )),
1847                    usage,
1848                ))
1849            }
1850        });
1851
1852        cx.spawn({
1853            async move |this, cx| {
1854                let Some((id, prediction)) =
1855                    Self::handle_api_response(&this, request_task.await, cx)?
1856                else {
1857                    return Ok(None);
1858                };
1859
1860                let Some((
1861                    inputs,
1862                    edited_buffer,
1863                    edited_buffer_snapshot,
1864                    edits,
1865                    received_response_at,
1866                )) = prediction
1867                else {
1868                    return Ok(Some(EditPredictionResult {
1869                        id,
1870                        prediction: Err(EditPredictionRejectReason::Empty),
1871                    }));
1872                };
1873
1874                // TODO telemetry: duration, etc
1875                Ok(Some(
1876                    EditPredictionResult::new(
1877                        id,
1878                        &edited_buffer,
1879                        &edited_buffer_snapshot,
1880                        edits.into(),
1881                        buffer_snapshotted_at,
1882                        received_response_at,
1883                        inputs,
1884                        cx,
1885                    )
1886                    .await,
1887                ))
1888            }
1889        })
1890    }
1891
    /// Sends a raw OpenAI-style request to the edit-prediction LLM endpoint
    /// and deserializes the response, together with any usage info extracted
    /// from the response headers.
    ///
    /// The URL comes from the `PREDICT_EDITS_URL` override when set, otherwise
    /// from the client's zed.dev LLM endpoint (`/predict_edits/raw`).
    ///
    /// With the `eval-support` feature, responses are memoized in `eval_cache`,
    /// keyed by `eval_cache_kind` plus a hash of the URL and the
    /// pretty-printed request JSON; a cache hit skips the network entirely
    /// (and reports `None` usage).
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            // Cache key = hash of the endpoint URL + the full request body.
            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                // Cache hit: no network round-trip, so usage is unknown.
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            // Cache miss: remember what to write back after the request succeeds.
            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Persist the fresh response for future eval runs.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1949
1950    fn handle_api_response<T>(
1951        this: &WeakEntity<Self>,
1952        response: Result<(T, Option<EditPredictionUsage>)>,
1953        cx: &mut gpui::AsyncApp,
1954    ) -> Result<T> {
1955        match response {
1956            Ok((data, usage)) => {
1957                if let Some(usage) = usage {
1958                    this.update(cx, |this, cx| {
1959                        this.user_store.update(cx, |user_store, cx| {
1960                            user_store.update_edit_prediction_usage(usage, cx);
1961                        });
1962                    })
1963                    .ok();
1964                }
1965                Ok(data)
1966            }
1967            Err(err) => {
1968                if err.is::<ZedUpdateRequiredError>() {
1969                    cx.update(|cx| {
1970                        this.update(cx, |this, _cx| {
1971                            this.update_required = true;
1972                        })
1973                        .ok();
1974
1975                        let error_message: SharedString = err.to_string().into();
1976                        show_app_notification(
1977                            NotificationId::unique::<ZedUpdateRequiredError>(),
1978                            cx,
1979                            move |cx| {
1980                                cx.new(|cx| {
1981                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1982                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1983                                })
1984                            },
1985                        );
1986                    })
1987                    .ok();
1988                }
1989                Err(err)
1990            }
1991        }
1992    }
1993
    /// Sends an authenticated POST built by `build` and deserializes the JSON
    /// response body into `Res`.
    ///
    /// Every request carries `Content-Type: application/json`, a bearer
    /// `Authorization` header with the current LLM token, and the Zed version
    /// header. If the response advertises a minimum required version newer
    /// than `app_version`, this fails with [`ZedUpdateRequiredError`]. If the
    /// server flags an expired LLM token, the token is refreshed and the
    /// request retried exactly once; any other failure bails with the status
    /// code and response body.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        // Guards the token-refresh retry so we loop at most twice.
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // The server may demand a minimum client version; bail early with
            // an update-required error if we're too old.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                // Usage info is optional; ignore header-parse failures.
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh once and retry the request.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2057
    /// Gap since the previous edit after which the user counts as idle; the
    /// first edit after an idle period triggers an immediate context refresh.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied while the user is actively typing before the related
    /// excerpts are refetched.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
2060
    // Refresh the related excerpts when the user just begins editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Background context retrieval only applies to the Zeta2 model in
        // agentic context mode.
        if !matches!(self.edit_prediction_model, ZetaEditPredictionModel::Zeta2) {
            return;
        }

        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // "Idle" = no edit recorded within CONTEXT_RETRIEVAL_IDLE_DURATION
        // (or no edit recorded at all yet).
        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // After idle, refresh immediately; otherwise wait out the debounce so
        // rapid typing doesn't refresh per keystroke. Overwriting the previous
        // debounce task drops (cancels) any refresh still waiting on its timer.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Keep the refresh task alive on the project so it runs to
                    // completion even after this debounce task finishes.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
2112
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    //
    // NOTE(review): this method does not itself guard against an already-running
    // refresh; single-flighting appears to rely on callers storing the returned
    // task in `refresh_context_task` — confirm that is the intended mechanism.
    //
    // Pipeline: build a search-planning prompt from the cursor excerpt and
    // recent events, ask the context-retrieval model to call the search tool,
    // run the returned queries, then store the resulting excerpts on the project.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // No-op unless the project is registered and agentic context mode is on.
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        // Excerpt around the cursor that the planning prompt is built from.
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        // Untitled buffers get a placeholder path.
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project.events(cx),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Function-local static: the tool schema is computed once per process.
        // The schema is expected to carry a `description` field; the unwrap
        // panics if it doesn't.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        // Single-turn, non-streaming request offering exactly one tool: the
        // retrieval search tool the model is expected to call.
        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            // The model must answer with an assistant message whose tool calls
            // carry the search queries.
            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect queries from every recognized tool call; calls to
            // unknown tools are logged and skipped, not treated as errors.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Clear the in-flight task slot and store the refreshed context on
            // the project (if it is still registered), notifying debug listeners.
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
2324
2325    pub fn set_context(
2326        &mut self,
2327        project: Entity<Project>,
2328        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2329    ) {
2330        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2331            zeta_project.context = Some(context);
2332        }
2333    }
2334
2335    fn gather_nearby_diagnostics(
2336        cursor_offset: usize,
2337        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
2338        snapshot: &BufferSnapshot,
2339        max_diagnostics_bytes: usize,
2340    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
2341        // TODO: Could make this more efficient
2342        let mut diagnostic_groups = Vec::new();
2343        for (language_server_id, diagnostics) in diagnostic_sets {
2344            let mut groups = Vec::new();
2345            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
2346            diagnostic_groups.extend(
2347                groups
2348                    .into_iter()
2349                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
2350            );
2351        }
2352
2353        // sort by proximity to cursor
2354        diagnostic_groups.sort_by_key(|group| {
2355            let range = &group.entries[group.primary_ix].range;
2356            if range.start >= cursor_offset {
2357                range.start - cursor_offset
2358            } else if cursor_offset >= range.end {
2359                cursor_offset - range.end
2360            } else {
2361                (cursor_offset - range.start).min(range.end - cursor_offset)
2362            }
2363        });
2364
2365        let mut results = Vec::new();
2366        let mut diagnostic_groups_truncated = false;
2367        let mut diagnostics_byte_count = 0;
2368        for group in diagnostic_groups {
2369            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
2370            diagnostics_byte_count += raw_value.get().len();
2371            if diagnostics_byte_count > max_diagnostics_bytes {
2372                diagnostic_groups_truncated = true;
2373                break;
2374            }
2375            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
2376        }
2377
2378        (results, diagnostic_groups_truncated)
2379    }
2380
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds a `PredictEditsRequest` for the zeta CLI from the current buffer
    /// state, gathering syntax context on a background thread.
    ///
    /// Returns an error task if the buffer has no file path. Panics (TODO) if
    /// the configured context mode is `Agentic`, which the CLI doesn't support
    /// yet.
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        // Clone the syntax index state handle (if any) so it can be locked on
        // the background thread below.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        // Snapshot all worktrees up front; they're needed to resolve
        // declaration paths in the request.
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, if it has one.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                    PredictEditsRequestTrigger::Other,
                )
            })
        })
    }
2459
2460    pub fn wait_for_initial_indexing(
2461        &mut self,
2462        project: &Entity<Project>,
2463        cx: &mut Context<Self>,
2464    ) -> Task<Result<()>> {
2465        let zeta_project = self.get_or_init_zeta_project(project, cx);
2466        if let Some(syntax_index) = &zeta_project.syntax_index {
2467            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2468        } else {
2469            Task::ready(Ok(()))
2470        }
2471    }
2472
2473    fn is_file_open_source(
2474        &self,
2475        project: &Entity<Project>,
2476        file: &Arc<dyn File>,
2477        cx: &App,
2478    ) -> bool {
2479        if !file.is_local() || file.is_private() {
2480            return false;
2481        }
2482        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2483            return false;
2484        };
2485        zeta_project
2486            .license_detection_watchers
2487            .get(&file.worktree_id(cx))
2488            .as_ref()
2489            .is_some_and(|watcher| watcher.is_project_open_source())
2490    }
2491
2492    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2493        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2494    }
2495
2496    fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2497        if !self.data_collection_choice.is_enabled() {
2498            return false;
2499        }
2500        events.iter().all(|event| {
2501            matches!(
2502                event.as_ref(),
2503                Event::BufferChange {
2504                    in_open_source_repo: true,
2505                    ..
2506                }
2507            )
2508        })
2509    }
2510
2511    fn load_data_collection_choice() -> DataCollectionChoice {
2512        let choice = KEY_VALUE_STORE
2513            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2514            .log_err()
2515            .flatten();
2516
2517        match choice.as_deref() {
2518            Some("true") => DataCollectionChoice::Enabled,
2519            Some("false") => DataCollectionChoice::Disabled,
2520            Some(_) => {
2521                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2522                DataCollectionChoice::NotAnswered
2523            }
2524            None => DataCollectionChoice::NotAnswered,
2525        }
2526    }
2527
    /// Iterates over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2531
    /// Number of predictions that have been shown to the user.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2535
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2539
    /// Records the user's rating for `prediction`, reports it via telemetry,
    /// and flushes telemetry immediately so the feedback isn't lost.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Flush right away rather than waiting for a periodic flush.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2558}
2559
2560pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2561    let choice = res.choices.pop()?;
2562    let output_text = match choice.message {
2563        open_ai::RequestMessage::Assistant {
2564            content: Some(open_ai::MessageContent::Plain(content)),
2565            ..
2566        } => content,
2567        open_ai::RequestMessage::Assistant {
2568            content: Some(open_ai::MessageContent::Multipart(mut content)),
2569            ..
2570        } => {
2571            if content.is_empty() {
2572                log::error!("No output from Baseten completion response");
2573                return None;
2574            }
2575
2576            match content.remove(0) {
2577                open_ai::MessagePart::Text { text } => text,
2578                open_ai::MessagePart::Image { .. } => {
2579                    log::error!("Expected text, got an image");
2580                    return None;
2581                }
2582            }
2583        }
2584        _ => {
2585            log::error!("Invalid response message: {:?}", choice.message);
2586            return None;
2587        }
2588    };
2589    Some(output_text)
2590}
2591
/// Error surfaced when the server reports that this Zed build is too old to
/// use edit predictions (see `MINIMUM_REQUIRED_VERSION_HEADER_NAME`).
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2599
2600fn make_syntax_context_cloud_request(
2601    excerpt_path: Arc<Path>,
2602    context: EditPredictionContext,
2603    events: Vec<Arc<predict_edits_v3::Event>>,
2604    can_collect_data: bool,
2605    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2606    diagnostic_groups_truncated: bool,
2607    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2608    debug_info: bool,
2609    worktrees: &Vec<worktree::Snapshot>,
2610    index_state: Option<&SyntaxIndexState>,
2611    prompt_max_bytes: Option<usize>,
2612    prompt_format: PromptFormat,
2613    trigger: PredictEditsRequestTrigger,
2614) -> predict_edits_v3::PredictEditsRequest {
2615    let mut signatures = Vec::new();
2616    let mut declaration_to_signature_index = HashMap::default();
2617    let mut referenced_declarations = Vec::new();
2618
2619    for snippet in context.declarations {
2620        let project_entry_id = snippet.declaration.project_entry_id();
2621        let Some(path) = worktrees.iter().find_map(|worktree| {
2622            worktree.entry_for_id(project_entry_id).map(|entry| {
2623                let mut full_path = RelPathBuf::new();
2624                full_path.push(worktree.root_name());
2625                full_path.push(&entry.path);
2626                full_path
2627            })
2628        }) else {
2629            continue;
2630        };
2631
2632        let parent_index = index_state.and_then(|index_state| {
2633            snippet.declaration.parent().and_then(|parent| {
2634                add_signature(
2635                    parent,
2636                    &mut declaration_to_signature_index,
2637                    &mut signatures,
2638                    index_state,
2639                )
2640            })
2641        });
2642
2643        let (text, text_is_truncated) = snippet.declaration.item_text();
2644        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2645            path: path.as_std_path().into(),
2646            text: text.into(),
2647            range: snippet.declaration.item_line_range(),
2648            text_is_truncated,
2649            signature_range: snippet.declaration.signature_range_in_item_text(),
2650            parent_index,
2651            signature_score: snippet.score(DeclarationStyle::Signature),
2652            declaration_score: snippet.score(DeclarationStyle::Declaration),
2653            score_components: snippet.components,
2654        });
2655    }
2656
2657    let excerpt_parent = index_state.and_then(|index_state| {
2658        context
2659            .excerpt
2660            .parent_declarations
2661            .last()
2662            .and_then(|(parent, _)| {
2663                add_signature(
2664                    *parent,
2665                    &mut declaration_to_signature_index,
2666                    &mut signatures,
2667                    index_state,
2668                )
2669            })
2670    });
2671
2672    predict_edits_v3::PredictEditsRequest {
2673        excerpt_path,
2674        excerpt: context.excerpt_text.body,
2675        excerpt_line_range: context.excerpt.line_range,
2676        excerpt_range: context.excerpt.range,
2677        cursor_point: predict_edits_v3::Point {
2678            line: predict_edits_v3::Line(context.cursor_point.row),
2679            column: context.cursor_point.column,
2680        },
2681        referenced_declarations,
2682        included_files: vec![],
2683        signatures,
2684        excerpt_parent,
2685        events,
2686        can_collect_data,
2687        diagnostic_groups,
2688        diagnostic_groups_truncated,
2689        git_info,
2690        debug_info,
2691        prompt_max_bytes,
2692        prompt_format,
2693        trigger,
2694    }
2695}
2696
2697fn add_signature(
2698    declaration_id: DeclarationId,
2699    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2700    signatures: &mut Vec<Signature>,
2701    index: &SyntaxIndexState,
2702) -> Option<usize> {
2703    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2704        return Some(*signature_index);
2705    }
2706    let Some(parent_declaration) = index.declaration(declaration_id) else {
2707        log::error!("bug: missing parent declaration");
2708        return None;
2709    };
2710    let parent_index = parent_declaration.parent().and_then(|parent| {
2711        add_signature(parent, declaration_to_signature_index, signatures, index)
2712    });
2713    let (text, text_is_truncated) = parent_declaration.signature_text();
2714    let signature_index = signatures.len();
2715    signatures.push(Signature {
2716        text: text.into(),
2717        text_is_truncated,
2718        parent_index,
2719        range: parent_declaration.signature_line_range(),
2720    });
2721    declaration_to_signature_index.insert(declaration_id, signature_index);
2722    Some(signature_index)
2723}
2724
/// Cache key for eval-support entries: the entry kind plus a 64-bit id
/// (presumably a hash of the cached input — confirm at call sites).
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

/// The kind of cached eval artifact.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}

#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    // Lowercase names, used as stable identifiers (do not localize).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EvalCacheEntryKind::Search => write!(f, "search"),
            EvalCacheEntryKind::Context => write!(f, "context"),
            EvalCacheEntryKind::Prediction => write!(f, "prediction"),
        }
    }
}

/// Read/write store for cached eval artifacts, keyed by [`EvalCacheKey`].
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2752
/// The user's decision about sharing edit-prediction data.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// True only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// True when the user has made an explicit choice either way.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the flipped choice; an unanswered state toggles to enabled.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
2793
2794struct ZedPredictUpsell;
2795
2796impl Dismissable for ZedPredictUpsell {
2797    const KEY: &'static str = "dismissed-edit-predict-upsell";
2798
2799    fn dismissed() -> bool {
2800        // To make this backwards compatible with older versions of Zed, we
2801        // check if the user has seen the previous Edit Prediction Onboarding
2802        // before, by checking the data collection choice which was written to
2803        // the database once the user clicked on "Accept and Enable"
2804        if KEY_VALUE_STORE
2805            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2806            .log_err()
2807            .is_some_and(|s| s.is_some())
2808        {
2809            return true;
2810        }
2811
2812        KEY_VALUE_STORE
2813            .read_kvp(Self::KEY)
2814            .log_err()
2815            .is_some_and(|s| s.is_some())
2816    }
2817}
2818
/// Whether the edit-prediction upsell modal should be shown (i.e. it has not
/// been dismissed, and the legacy onboarding was never answered).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2822
/// Registers edit-prediction workspace actions and sets up their
/// command-palette feature gating.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Rating UI is only reachable behind the feature flag.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding clears the configured provider in settings.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2858
/// Shows/hides edit-prediction actions in the command palette based on the
/// AI-disabled setting and the rate-completions feature flag.
fn feature_gate_predict_edits_actions(cx: &mut App) {
    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
    let zeta_all_action_types = [
        TypeId::of::<RateCompletions>(),
        TypeId::of::<ResetOnboarding>(),
        zed_actions::OpenZedPredictOnboarding.type_id(),
        TypeId::of::<ClearHistory>(),
        TypeId::of::<ThumbsUpActivePrediction>(),
        TypeId::of::<ThumbsDownActivePrediction>(),
        TypeId::of::<NextEdit>(),
        TypeId::of::<PreviousEdit>(),
    ];

    // Hidden by default; the observers below re-show them when appropriate.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_action_types(&rate_completion_action_types);
        filter.hide_action_types(&reset_onboarding_action_types);
        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
    });

    // React to settings changes: disabling AI hides every zeta action; the
    // rate-completions action additionally requires the feature flag.
    cx.observe_global::<SettingsStore>(move |cx| {
        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            if is_ai_disabled {
                filter.hide_action_types(&zeta_all_action_types);
            } else if has_feature_flag {
                filter.show_action_types(&rate_completion_action_types);
            } else {
                filter.hide_action_types(&rate_completion_action_types);
            }
        });
    })
    .detach();

    // React to feature-flag changes (no-op while AI is disabled).
    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
        if !DisableAiSettings::get_global(cx).disable_ai {
            if is_enabled {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.show_action_types(&rate_completion_action_types);
                });
            } else {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.hide_action_types(&rate_completion_action_types);
                });
            }
        }
    })
    .detach();
}
2910
2911#[cfg(test)]
2912mod tests {
2913    use std::{path::Path, sync::Arc, time::Duration};
2914
2915    use client::UserStore;
2916    use clock::FakeSystemClock;
2917    use cloud_llm_client::{
2918        EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
2919    };
2920    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2921    use futures::{
2922        AsyncReadExt, StreamExt,
2923        channel::{mpsc, oneshot},
2924    };
2925    use gpui::{
2926        Entity, TestAppContext,
2927        http_client::{FakeHttpClient, Response},
2928        prelude::*,
2929    };
2930    use indoc::indoc;
2931    use language::OffsetRangeExt as _;
2932    use open_ai::Usage;
2933    use pretty_assertions::{assert_eq, assert_matches};
2934    use project::{FakeFs, Project};
2935    use serde_json::json;
2936    use settings::SettingsStore;
2937    use util::path;
2938    use uuid::Uuid;
2939
2940    use crate::{BufferEditPrediction, EditPredictionId, REJECT_REQUEST_DEBOUNCE, Zeta};
2941
    // Exercises the lifecycle of the "current" prediction: a local prediction
    // for the active buffer, a context refresh via the search tool call, and a
    // prediction targeting a different file (surfaced as a Jump, then as Local
    // once that buffer is open).
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        // Edits in the active buffer surface as a Local prediction.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        // Respond with a search tool call targeting 2.txt instead of text.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // From buffer1's perspective, a prediction in 2.txt is a Jump.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once 2.txt itself is open, the same prediction is Local there.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
3088
    // A single prediction round-trip: request, model responds with a unified
    // diff, and the resulting edit lands at the expected buffer position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        // The diff "How" -> "How are you?" becomes a single insertion at (1, 3).
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3152
    // Verifies that buffer edits made before a prediction request are included
    // in the prompt as a unified diff of the recent change history.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Registering the buffer makes zeta track its edit events.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (request, respond_tx) = requests.predict.next().await.unwrap();

        // The edit above must appear in the prompt as a diff event.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3226
3227    #[gpui::test]
3228    async fn test_empty_prediction(cx: &mut TestAppContext) {
3229        let (zeta, mut requests) = init_test(cx);
3230        let fs = FakeFs::new(cx.executor());
3231        fs.insert_tree(
3232            "/root",
3233            json!({
3234                "foo.md":  "Hello!\nHow\nBye\n"
3235            }),
3236        )
3237        .await;
3238        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3239
3240        let buffer = project
3241            .update(cx, |project, cx| {
3242                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3243                project.open_buffer(path, cx)
3244            })
3245            .await
3246            .unwrap();
3247        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3248        let position = snapshot.anchor_before(language::Point::new(1, 3));
3249
3250        zeta.update(cx, |zeta, cx| {
3251            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3252        });
3253
3254        const NO_OP_DIFF: &str = indoc! { r"
3255            --- a/root/foo.md
3256            +++ b/root/foo.md
3257            @@ ... @@
3258             Hello!
3259            -How
3260            +How
3261             Bye
3262        "};
3263
3264        let (_, respond_tx) = requests.predict.next().await.unwrap();
3265        let response = model_response(NO_OP_DIFF);
3266        let id = response.id.clone();
3267        respond_tx.send(response).unwrap();
3268
3269        cx.run_until_parked();
3270
3271        zeta.read_with(cx, |zeta, cx| {
3272            assert!(
3273                zeta.current_prediction_for_buffer(&buffer, &project, cx)
3274                    .is_none()
3275            );
3276        });
3277
3278        // prediction is reported as rejected
3279        let (reject_request, _) = requests.reject.next().await.unwrap();
3280
3281        assert_eq!(
3282            &reject_request.rejections,
3283            &[EditPredictionRejection {
3284                request_id: id,
3285                reason: EditPredictionRejectReason::Empty,
3286                was_shown: false
3287            }]
3288        );
3289    }
3290
3291    #[gpui::test]
3292    async fn test_interpolated_empty(cx: &mut TestAppContext) {
3293        let (zeta, mut requests) = init_test(cx);
3294        let fs = FakeFs::new(cx.executor());
3295        fs.insert_tree(
3296            "/root",
3297            json!({
3298                "foo.md":  "Hello!\nHow\nBye\n"
3299            }),
3300        )
3301        .await;
3302        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3303
3304        let buffer = project
3305            .update(cx, |project, cx| {
3306                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3307                project.open_buffer(path, cx)
3308            })
3309            .await
3310            .unwrap();
3311        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3312        let position = snapshot.anchor_before(language::Point::new(1, 3));
3313
3314        zeta.update(cx, |zeta, cx| {
3315            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3316        });
3317
3318        let (_, respond_tx) = requests.predict.next().await.unwrap();
3319
3320        buffer.update(cx, |buffer, cx| {
3321            buffer.set_text("Hello!\nHow are you?\nBye", cx);
3322        });
3323
3324        let response = model_response(SIMPLE_DIFF);
3325        let id = response.id.clone();
3326        respond_tx.send(response).unwrap();
3327
3328        cx.run_until_parked();
3329
3330        zeta.read_with(cx, |zeta, cx| {
3331            assert!(
3332                zeta.current_prediction_for_buffer(&buffer, &project, cx)
3333                    .is_none()
3334            );
3335        });
3336
3337        // prediction is reported as rejected
3338        let (reject_request, _) = requests.reject.next().await.unwrap();
3339
3340        assert_eq!(
3341            &reject_request.rejections,
3342            &[EditPredictionRejection {
3343                request_id: id,
3344                reason: EditPredictionRejectReason::InterpolatedEmpty,
3345                was_shown: false
3346            }]
3347        );
3348    }
3349
    /// Model response shared by several tests below: rewrites the second line
    /// of `root/foo.md` from "How" to "How are you?".
    const SIMPLE_DIFF: &str = indoc! { r"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ ... @@
         Hello!
        -How
        +How are you?
         Bye
    "};
3359
    // A newer completed prediction replaces the current one, and the replaced
    // prediction is reported to the server with reason `Replaced`.
    #[gpui::test]
    async fn test_replace_current(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_tx.send(first_response).unwrap();

        cx.run_until_parked();

        // The first response becomes the current prediction.
        zeta.read_with(cx, |zeta, cx| {
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // a second request is triggered
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_tx.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // second replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as replaced
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }]
        );
    }
3439
    // When a newer prediction is worse than the current one, the current
    // prediction is kept and the newer one is rejected with reason
    // `CurrentPreferred`.
    #[gpui::test]
    async fn test_current_preferred(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_tx.send(first_response).unwrap();

        cx.run_until_parked();

        // The first response becomes the current prediction.
        zeta.read_with(cx, |zeta, cx| {
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // a second request is triggered
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        // worse than current prediction: completes only "How are" instead of
        // the first response's full "How are you?" edit
        let second_response = model_response(indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "});
        let second_id = second_response.id.clone();
        respond_tx.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // first is preferred over second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // second is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: second_id,
                reason: EditPredictionRejectReason::CurrentPreferred,
                was_shown: false
            }]
        );
    }
3528
    // When a later in-flight request completes first, the earlier pending
    // request is cancelled: its late response is discarded and it is reported
    // with reason `Canceled`.
    #[gpui::test]
    async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        // start two refresh tasks
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_first) = requests.predict.next().await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_second) = requests.predict.next().await.unwrap();

        // wait for throttle
        cx.run_until_parked();

        // second responds first
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_second.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // The first request's response arrives only after the second one won.
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still second, since first was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            }]
        );
    }
3617
    // With two requests already pending, starting a third cancels the second
    // one. The cancelled response is ignored when it arrives, and the third
    // response ultimately replaces the first.
    #[gpui::test]
    async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        // start two refresh tasks
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_first) = requests.predict.next().await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_second) = requests.predict.next().await.unwrap();

        // wait for throttle, so requests are sent
        cx.run_until_parked();

        zeta.update(cx, |zeta, cx| {
            // start a third request
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

            // 2 are pending, so 2nd is cancelled (index 1 in the pending list)
            assert_eq!(
                zeta.get_or_init_zeta_project(&project, cx)
                    .cancelled_predictions
                    .iter()
                    .copied()
                    .collect::<Vec<_>>(),
                [1]
            );
        });

        // wait for throttle
        cx.run_until_parked();

        let (_, respond_third) = requests.predict.next().await.unwrap();

        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // The cancelled (second) request responds; its result must be ignored.
        let cancelled_response = model_response(SIMPLE_DIFF);
        let cancelled_id = cancelled_response.id.clone();
        respond_second.send(cancelled_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still first, since second was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        let third_response = model_response(SIMPLE_DIFF);
        let third_response_id = third_response.id.clone();
        respond_third.send(third_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // third completes and replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                third_response_id
            );
        });

        // second is reported as cancelled, and first as replaced by third
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[
                EditPredictionRejection {
                    request_id: cancelled_id,
                    reason: EditPredictionRejectReason::Canceled,
                    was_shown: false
                },
                EditPredictionRejection {
                    request_id: first_id,
                    reason: EditPredictionRejectReason::Replaced,
                    was_shown: false
                }
            ]
        );
    }
3749
    // Rejections are debounced and batched; hitting the batch-size limit
    // flushes immediately, and a failed reject request is retried together
    // with subsequently queued rejections.
    #[gpui::test]
    async fn test_rejections_flushing(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);

        zeta.update(cx, |zeta, _cx| {
            zeta.reject_prediction(
                EditPredictionId("test-1".into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
            zeta.reject_prediction(
                EditPredictionId("test-2".into()),
                EditPredictionRejectReason::Canceled,
                true,
            );
        });

        // Nothing is sent until the debounce interval elapses.
        cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
        cx.run_until_parked();

        let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
        respond_tx.send(()).unwrap();

        // batched
        assert_eq!(reject_request.rejections.len(), 2);
        assert_eq!(
            reject_request.rejections[0],
            EditPredictionRejection {
                request_id: "test-1".to_string(),
                reason: EditPredictionRejectReason::Discarded,
                was_shown: false
            }
        );
        assert_eq!(
            reject_request.rejections[1],
            EditPredictionRejection {
                request_id: "test-2".to_string(),
                reason: EditPredictionRejectReason::Canceled,
                was_shown: true
            }
        );

        // Reaching batch size limit sends without debounce
        zeta.update(cx, |zeta, _cx| {
            for i in 0..70 {
                zeta.reject_prediction(
                    EditPredictionId(format!("batch-{}", i).into()),
                    EditPredictionRejectReason::Discarded,
                    false,
                );
            }
        });

        // First MAX/2 items are sent immediately
        // (50 here — NOTE(review): assumes
        // MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST == 100; confirm)
        cx.run_until_parked();
        let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
        respond_tx.send(()).unwrap();

        assert_eq!(reject_request.rejections.len(), 50);
        assert_eq!(reject_request.rejections[0].request_id, "batch-0");
        assert_eq!(reject_request.rejections[49].request_id, "batch-49");

        // Remaining items are debounced with the next batch
        // (15s is assumed to exceed REJECT_REQUEST_DEBOUNCE)
        cx.executor().advance_clock(Duration::from_secs(15));
        cx.run_until_parked();

        let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
        respond_tx.send(()).unwrap();

        assert_eq!(reject_request.rejections.len(), 20);
        assert_eq!(reject_request.rejections[0].request_id, "batch-50");
        assert_eq!(reject_request.rejections[19].request_id, "batch-69");

        // Request failure
        zeta.update(cx, |zeta, _cx| {
            zeta.reject_prediction(
                EditPredictionId("retry-1".into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
        });

        cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
        cx.run_until_parked();

        let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
        assert_eq!(reject_request.rejections.len(), 1);
        assert_eq!(reject_request.rejections[0].request_id, "retry-1");
        // Simulate failure by dropping the response channel
        drop(_respond_tx);

        // Add another rejection
        zeta.update(cx, |zeta, _cx| {
            zeta.reject_prediction(
                EditPredictionId("retry-2".into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
        });

        cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
        cx.run_until_parked();

        // Retry should include both the failed item and the new one
        let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
        respond_tx.send(()).unwrap();

        assert_eq!(reject_request.rejections.len(), 2);
        assert_eq!(reject_request.rejections[0].request_id, "retry-1");
        assert_eq!(reject_request.rejections[1].request_id, "retry-2");
    }
3861
3862    // Skipped until we start including diagnostics in prompt
3863    // #[gpui::test]
3864    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3865    //     let (zeta, mut req_rx) = init_test(cx);
3866    //     let fs = FakeFs::new(cx.executor());
3867    //     fs.insert_tree(
3868    //         "/root",
3869    //         json!({
3870    //             "foo.md": "Hello!\nBye"
3871    //         }),
3872    //     )
3873    //     .await;
3874    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3875
3876    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3877    //     let diagnostic = lsp::Diagnostic {
3878    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3879    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
3880    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3881    //         ..Default::default()
3882    //     };
3883
3884    //     project.update(cx, |project, cx| {
3885    //         project.lsp_store().update(cx, |lsp_store, cx| {
3886    //             // Create some diagnostics
3887    //             lsp_store
3888    //                 .update_diagnostics(
3889    //                     LanguageServerId(0),
3890    //                     lsp::PublishDiagnosticsParams {
3891    //                         uri: path_to_buffer_uri.clone(),
3892    //                         diagnostics: vec![diagnostic],
3893    //                         version: None,
3894    //                     },
3895    //                     None,
3896    //                     language::DiagnosticSourceKind::Pushed,
3897    //                     &[],
3898    //                     cx,
3899    //                 )
3900    //                 .unwrap();
3901    //         });
3902    //     });
3903
3904    //     let buffer = project
3905    //         .update(cx, |project, cx| {
3906    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3907    //             project.open_buffer(path, cx)
3908    //         })
3909    //         .await
3910    //         .unwrap();
3911
3912    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3913    //     let position = snapshot.anchor_before(language::Point::new(0, 0));
3914
3915    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
3916    //         zeta.request_prediction(&project, &buffer, position, cx)
3917    //     });
3918
3919    //     let (request, _respond_tx) = req_rx.next().await.unwrap();
3920
3921    //     assert_eq!(request.diagnostic_groups.len(), 1);
3922    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3923    //         .unwrap();
3924    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3925    //     assert_eq!(
3926    //         value,
3927    //         json!({
3928    //             "entries": [{
3929    //                 "range": {
3930    //                     "start": 8,
3931    //                     "end": 10
3932    //                 },
3933    //                 "diagnostic": {
3934    //                     "source": null,
3935    //                     "code": null,
3936    //                     "code_description": null,
3937    //                     "severity": 1,
3938    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3939    //                     "markdown": null,
3940    //                     "group_id": 0,
3941    //                     "is_primary": true,
3942    //                     "is_disk_based": false,
3943    //                     "is_unnecessary": false,
3944    //                     "source_kind": "Pushed",
3945    //                     "data": null,
3946    //                     "underline": true
3947    //                 }
3948    //             }],
3949    //             "primary_ix": 0
3950    //         })
3951    //     );
3952    // }
3953
3954    fn model_response(text: &str) -> open_ai::Response {
3955        open_ai::Response {
3956            id: Uuid::new_v4().to_string(),
3957            object: "response".into(),
3958            created: 0,
3959            model: "model".into(),
3960            choices: vec![open_ai::Choice {
3961                index: 0,
3962                message: open_ai::RequestMessage::Assistant {
3963                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
3964                    tool_calls: vec![],
3965                },
3966                finish_reason: None,
3967            }],
3968            usage: Usage {
3969                prompt_tokens: 0,
3970                completion_tokens: 0,
3971                total_tokens: 0,
3972            },
3973        }
3974    }
3975
3976    fn prompt_from_request(request: &open_ai::Request) -> &str {
3977        assert_eq!(request.messages.len(), 1);
3978        let open_ai::RequestMessage::User {
3979            content: open_ai::MessageContent::Plain(content),
3980            ..
3981        } = &request.messages[0]
3982        else {
3983            panic!(
3984                "Request does not have single user message of type Plain. {:#?}",
3985                request
3986            );
3987        };
3988        content
3989    }
3990
    /// Receiving ends of the fake HTTP endpoints installed by `init_test`.
    /// Each item pairs a decoded request body with a oneshot sender the test
    /// uses to supply the response.
    struct RequestChannels {
        predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
        reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
    }
3995
    /// Creates a `Zeta` entity backed by a fake HTTP client, returning it
    /// along with channels for intercepting its prediction and rejection
    /// requests.
    ///
    /// The fake client serves three paths:
    /// - `/client/llm_tokens`: returns a static token.
    /// - `/predict_edits/raw`: forwards the decoded body over
    ///   `RequestChannels::predict` and awaits the test-supplied response.
    /// - `/predict_edits/reject`: same, over `RequestChannels::reject`.
    fn init_test(cx: &mut TestAppContext) -> (Entity<Zeta>, RequestChannels) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
            let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let predict_req_tx = predict_req_tx.clone();
                    let reject_req_tx = reject_req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the request to the test and wait for it
                                // to send a response; a dropped sender turns
                                // into an error (simulated request failure).
                                let (res_tx, res_rx) = oneshot::channel();
                                predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            "/predict_edits/reject" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (
                zeta,
                RequestChannels {
                    predict: predict_req_rx,
                    reject: reject_req_rx,
                },
            )
        })
    }
4062}