zeta.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
   7    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
   8    MINIMUM_REQUIRED_VERSION_HEADER_NAME, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
   9};
  10use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  11use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  12use collections::{HashMap, HashSet};
  13use command_palette_hooks::CommandPaletteFilter;
  14use db::kvp::{Dismissable, KEY_VALUE_STORE};
  15use edit_prediction_context::{
  16    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  17    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  18    SyntaxIndex, SyntaxIndexState,
  19};
  20use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
  21use futures::channel::mpsc::UnboundedReceiver;
  22use futures::channel::{mpsc, oneshot};
  23use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, select_biased};
  24use gpui::{
  25    App, AsyncApp, BackgroundExecutor, Entity, EntityId, Global, SharedString, Subscription, Task,
  26    WeakEntity, actions,
  27    http_client::{self, AsyncBody, Method},
  28    prelude::*,
  29};
  30use language::{
  31    Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
  32};
  33use language::{BufferSnapshot, OffsetRangeExt};
  34use language_model::{LlmApiToken, RefreshLlmTokenListener};
  35use open_ai::FunctionDefinition;
  36use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  37use release_channel::AppVersion;
  38use semver::Version;
  39use serde::de::DeserializeOwned;
  40use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
  41use std::any::{Any as _, TypeId};
  42use std::collections::{VecDeque, hash_map};
  43use telemetry_events::EditPredictionRating;
  44use workspace::Workspace;
  45
  46use std::ops::Range;
  47use std::path::Path;
  48use std::rc::Rc;
  49use std::str::FromStr as _;
  50use std::sync::{Arc, LazyLock};
  51use std::time::{Duration, Instant};
  52use std::{env, mem};
  53use thiserror::Error;
  54use util::rel_path::RelPathBuf;
  55use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
  56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  57
// Submodules of the zeta edit-prediction crate.
pub mod assemble_excerpts;
mod license_detection;
mod onboarding_modal;
mod prediction;
mod provider;
mod rate_prediction_modal;
pub mod retrieval_search;
mod sweep_ai;
pub mod udiff;
mod xml_edits;
pub mod zeta1;

#[cfg(test)]
mod zeta_tests;
  72
  73use crate::assemble_excerpts::assemble_excerpts;
  74use crate::license_detection::LicenseDetectionWatcher;
  75use crate::onboarding_modal::ZedPredictModal;
  76pub use crate::prediction::EditPrediction;
  77pub use crate::prediction::EditPredictionId;
  78pub use crate::prediction::EditPredictionInputs;
  79use crate::prediction::EditPredictionResult;
  80use crate::rate_prediction_modal::{
  81    NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
  82    ThumbsUpActivePrediction,
  83};
  84use crate::sweep_ai::SweepAi;
  85use crate::zeta1::request_prediction_with_zeta1;
  86pub use provider::ZetaEditPredictionProvider;
  87
// Workspace actions contributed by the edit-prediction system.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  99
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Edits within this many rows of the previous edit are coalesced into a
/// single change event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// KVP-store key recording the user's data-collection opt-in choice.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval for batching reject-prediction requests.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 105
 106pub struct SweepFeatureFlag;
 107
 108impl FeatureFlag for SweepFeatureFlag {
 109    const NAME: &str = "sweep-ai";
 110}
/// Default sizing for the excerpt around the cursor: between 128 and 512
/// bytes, aiming to place half of the excerpt before the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};
 116
/// Default context mode: agentic retrieval with the default excerpt options.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Default options for `ContextMode::Agentic`.
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};
 123
/// Default options for `ContextMode::Syntax` (syntax-index based retrieval).
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        // No additional declarations are retrieved by default.
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };
 133
/// Default `ZetaOptions` used when constructing a `Zeta` instance.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
 142
 143static USE_OLLAMA: LazyLock<bool> =
 144    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 145static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 146    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
 147        "qwen3-coder:30b".to_string()
 148    } else {
 149        "yqvev8r3".to_string()
 150    })
 151});
 152static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 153    match env::var("ZED_ZETA2_MODEL").as_deref() {
 154        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 155        Ok(model) => model,
 156        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 157        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 158    }
 159    .to_string()
 160});
 161static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 162    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 163        if *USE_OLLAMA {
 164            Some("http://localhost:11434/v1/chat/completions".into())
 165        } else {
 166            None
 167        }
 168    })
 169});
 170
/// Feature flag gating the Zeta2 edit-prediction backend.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Always on for Zed staff, regardless of the remote flag value.
    fn enabled_for_staff() -> bool {
        true
    }
}
 180
/// Newtype registering the singleton `Zeta` entity as a GPUI global.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 185
/// Coordinates edit predictions: tracks per-project state, talks to the
/// prediction backends, and records which predictions were shown and rated.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Token for the LLM service; refreshed via `_llm_token_subscription`.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project prediction state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server requires a newer client
    // (cf. MINIMUM_REQUIRED_VERSION_HEADER_NAME) — confirm where it is set.
    update_required: bool,
    // When present, debug events are streamed here (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    // Rejections queued here are batched and sent by
    // `handle_rejected_predictions` on the background executor.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
 204
/// Which backend serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}
 212
/// Tunable parameters for building prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    // Upper bound on the assembled prompt size, in bytes.
    pub max_prompt_bytes: usize,
    // Upper bound on diagnostic text included in the prompt, in bytes.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    // Number of files indexed concurrently by the syntax index.
    pub file_indexing_parallelism: usize,
    // NOTE(review): presumably the time window for coalescing buffer
    // changes into one event — confirm at the use site.
    pub buffer_change_grouping_interval: Duration,
}
 222
/// How context for a prediction request is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Model-driven retrieval via search tool calls.
    Agentic(AgenticContextOptions),
    /// Retrieval backed by the local syntax index.
    Syntax(EditPredictionContextOptions),
}
 228
/// Options for `ContextMode::Agentic`.
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
 233
 234impl ContextMode {
 235    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 236        match self {
 237            ContextMode::Agentic(options) => &options.excerpt,
 238            ContextMode::Syntax(options) => &options.excerpt,
 239        }
 240    }
 241}
 242
/// Debug events streamed to the channel returned by `Zeta::debug_info`.
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}
 251
/// Payload for `ZetaDebugInfo::ContextRetrievalStarted`.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    // The prompt sent to the context-retrieval model.
    pub search_prompt: String,
}
 258
/// Payload for context-retrieval progress events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}
 264
/// Payload for `ZetaDebugInfo::EditPredictionRequested`.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    // How long context retrieval took before the request was made.
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    // The locally assembled prompt, or an error message if assembly failed.
    pub local_prompt: Result<String, String>,
    // Resolves with the model response (or error) and the request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}
 274
/// Payload for `ZetaDebugInfo::SearchQueriesGenerated`.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}
 281
/// Server-side debug info attached to prediction responses.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 283
/// Per-project prediction state.
struct ZetaProject {
    // Only populated when `ContextMode::Syntax` is active.
    syntax_index: Option<Entity<SyntaxIndex>>,
    // Finalized buffer-change events, oldest first; bounded by EVENT_COUNT_MAX.
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    // The most recent change, kept unfinalized so nearby follow-up edits can
    // be coalesced into it (see `report_changes_for_buffer`).
    last_event: Option<LastEvent>,
    // Recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // At most two prediction requests may be in flight at once.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions cancelled before their task resolved.
    cancelled_predictions: HashSet<usize>,
    // Context ranges retrieved for upcoming prediction requests.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    // One license watcher per worktree; used to tag events as open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
 302
 303impl ZetaProject {
 304    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
 305        self.events
 306            .iter()
 307            .cloned()
 308            .chain(
 309                self.last_event
 310                    .as_ref()
 311                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 312            )
 313            .collect()
 314    }
 315
 316    fn cancel_pending_prediction(
 317        &mut self,
 318        pending_prediction: PendingPrediction,
 319        cx: &mut Context<Zeta>,
 320    ) {
 321        self.cancelled_predictions.insert(pending_prediction.id);
 322
 323        cx.spawn(async move |this, cx| {
 324            let Some(prediction_id) = pending_prediction.task.await else {
 325                return;
 326            };
 327
 328            this.update(cx, |this, _cx| {
 329                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 330            })
 331            .ok();
 332        })
 333        .detach()
 334    }
 335}
 336
/// The prediction currently offered to the user for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    // What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // Whether the prediction has been displayed to the user.
    pub was_shown: bool,
}
 343
impl CurrentEditPrediction {
    /// Whether `self` (a newly computed prediction) should replace
    /// `old_prediction` as the project's current prediction.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // If the new prediction no longer applies to the buffer's current
        // content, keep the old one.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        // A prediction for a different buffer always wins.
        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // If the old prediction no longer applies, replace it.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // Only replace the displayed edit when the new one extends the
            // old one at the same range (e.g. more streamed-in text).
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
 382
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// A diagnostics update on the project.
    DiagnosticsUpdate,
    /// An edit or cursor movement in the given buffer.
    Buffer(EntityId),
}
 388
 389impl PredictionRequestedBy {
 390    pub fn buffer_id(&self) -> Option<EntityId> {
 391        match self {
 392            PredictionRequestedBy::DiagnosticsUpdate => None,
 393            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 394        }
 395    }
 396}
 397
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    // Id used to match against `ZetaProject::cancelled_predictions`.
    id: usize,
    // Resolves to the resulting prediction's id, if one was produced.
    task: Task<Option<EditPredictionId>>,
}
 403
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer directly.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer; offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}
 410
 411#[cfg(test)]
 412impl std::ops::Deref for BufferEditPrediction<'_> {
 413    type Target = EditPrediction;
 414
 415    fn deref(&self) -> &Self::Target {
 416        match self {
 417            BufferEditPrediction::Local { prediction } => prediction,
 418            BufferEditPrediction::Jump { prediction } => prediction,
 419        }
 420    }
 421}
 422
/// A buffer being tracked for change events.
struct RegisteredBuffer {
    // Snapshot as of the last reported change (see `report_changes_for_buffer`).
    snapshot: BufferSnapshot,
    // Edit subscription plus release observer (see `register_buffer_impl`).
    _subscriptions: [gpui::Subscription; 2],
}
 427
/// The most recent buffer change, held open so that nearby follow-up edits
/// can be coalesced before the event is finalized.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    // Position of the last edit, used to decide whether to coalesce.
    end_edit_anchor: Option<Anchor>,
}
 433
 434impl LastEvent {
 435    pub fn finalize(
 436        &self,
 437        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 438        cx: &App,
 439    ) -> Option<Arc<predict_edits_v3::Event>> {
 440        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
 441        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
 442
 443        let file = self.new_snapshot.file();
 444        let old_file = self.old_snapshot.file();
 445
 446        let in_open_source_repo = [file, old_file].iter().all(|file| {
 447            file.is_some_and(|file| {
 448                license_detection_watchers
 449                    .get(&file.worktree_id(cx))
 450                    .is_some_and(|watcher| watcher.is_project_open_source())
 451            })
 452        });
 453
 454        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
 455
 456        if path == old_path && diff.is_empty() {
 457            None
 458        } else {
 459            Some(Arc::new(predict_edits_v3::Event::BufferChange {
 460                old_path,
 461                path,
 462                diff,
 463                in_open_source_repo,
 464                // TODO: Actually detect if this edit was predicted or not
 465                predicted: false,
 466            }))
 467        }
 468    }
 469}
 470
 471fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
 472    if let Some(file) = snapshot.file() {
 473        file.full_path(cx).into()
 474    } else {
 475        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 476    }
 477}
 478
 479impl Zeta {
 480    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 481        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 482    }
 483
 484    pub fn global(
 485        client: &Arc<Client>,
 486        user_store: &Entity<UserStore>,
 487        cx: &mut App,
 488    ) -> Entity<Self> {
 489        cx.try_global::<ZetaGlobal>()
 490            .map(|global| global.0.clone())
 491            .unwrap_or_else(|| {
 492                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 493                cx.set_global(ZetaGlobal(zeta.clone()));
 494                zeta
 495            })
 496    }
 497
    /// Creates a `Zeta`, spawning the background reporter for rejected
    /// predictions and subscribing to LLM-token refresh events.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Background worker that batches and reports rejected predictions
        // sent through `reject_predictions_tx`.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token,
            // Refresh the LLM token whenever the listener signals a change
            // (e.g. the previous token expired).
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
 553
    /// Selects which backend model serves edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }
 557
    /// Whether a Sweep AI API token is configured.
    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_ai.api_token.is_some()
    }
 561
    /// Installs a cache used during evals to memoize backend interactions.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 566
 567    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 568        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 569        self.debug_tx = Some(debug_watch_tx);
 570        debug_watch_rx
 571    }
 572
    /// The currently active prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 576
    /// Replaces the prediction options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 580
 581    pub fn clear_history(&mut self) {
 582        for zeta_project in self.projects.values_mut() {
 583            zeta_project.events.clear();
 584        }
 585    }
 586
 587    pub fn context_for_project(
 588        &self,
 589        project: &Entity<Project>,
 590    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 591        self.projects
 592            .get(&project.entity_id())
 593            .and_then(|project| {
 594                Some(
 595                    project
 596                        .context
 597                        .as_ref()?
 598                        .iter()
 599                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 600                )
 601            })
 602            .into_iter()
 603            .flatten()
 604    }
 605
 606    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 607        if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
 608            self.user_store.read(cx).edit_prediction_usage()
 609        } else {
 610            None
 611        }
 612    }
 613
    /// Ensures per-project prediction state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }
 617
    /// Starts tracking `buffer` within `project`, initializing project
    /// state as needed.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 627
 628    fn get_or_init_zeta_project(
 629        &mut self,
 630        project: &Entity<Project>,
 631        cx: &mut Context<Self>,
 632    ) -> &mut ZetaProject {
 633        self.projects
 634            .entry(project.entity_id())
 635            .or_insert_with(|| ZetaProject {
 636                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
 637                    Some(cx.new(|cx| {
 638                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 639                    }))
 640                } else {
 641                    None
 642                },
 643                events: VecDeque::new(),
 644                last_event: None,
 645                recent_paths: VecDeque::new(),
 646                registered_buffers: HashMap::default(),
 647                current_prediction: None,
 648                cancelled_predictions: HashSet::default(),
 649                pending_predictions: ArrayVec::new(),
 650                next_pending_prediction_id: 0,
 651                last_prediction_refresh: None,
 652                context: None,
 653                refresh_context_task: None,
 654                refresh_context_debounce_task: None,
 655                refresh_context_timestamp: None,
 656                license_detection_watchers: HashMap::default(),
 657                _subscription: cx.subscribe(&project, Self::handle_project_event),
 658            })
 659    }
 660
 661    fn handle_project_event(
 662        &mut self,
 663        project: Entity<Project>,
 664        event: &project::Event,
 665        cx: &mut Context<Self>,
 666    ) {
 667        // TODO [zeta2] init with recent paths
 668        match event {
 669            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 670                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
 671                    return;
 672                };
 673                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 674                if let Some(path) = path {
 675                    if let Some(ix) = zeta_project
 676                        .recent_paths
 677                        .iter()
 678                        .position(|probe| probe == &path)
 679                    {
 680                        zeta_project.recent_paths.remove(ix);
 681                    }
 682                    zeta_project.recent_paths.push_front(path);
 683                }
 684            }
 685            project::Event::DiagnosticsUpdated { .. } => {
 686                if cx.has_flag::<Zeta2FeatureFlag>() {
 687                    self.refresh_prediction_from_diagnostics(project, cx);
 688                }
 689            }
 690            _ => (),
 691        }
 692    }
 693
    /// Ensures `buffer` is registered in `zeta_project`: installs a license
    /// watcher for the buffer's worktree and subscribes to buffer edits and
    /// release. Returns the (possibly pre-existing) registration.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily create one license watcher per worktree, and drop it again
        // when the worktree is released.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record edits so they become change events.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 753
    /// Records a buffer change as an event, coalescing it into the pending
    /// `last_event` when it is a nearby follow-up edit to the same buffer.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        // Nothing to report if the buffer hasn't changed since last time.
        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Anchor of the last edit in this change, used for proximity checks.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // Coalesce only when this change directly continues the pending
            // event's snapshot chain on the same buffer...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and the edits land within CHANGE_GROUPING_LINE_SPAN rows
            // of each other.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Keep the total (finalized events + the new pending one) bounded
        // by EVENT_COUNT_MAX.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        // Finalize the previous pending event before starting a new one.
        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
 816
 817    fn current_prediction_for_buffer(
 818        &self,
 819        buffer: &Entity<Buffer>,
 820        project: &Entity<Project>,
 821        cx: &App,
 822    ) -> Option<BufferEditPrediction<'_>> {
 823        let project_state = self.projects.get(&project.entity_id())?;
 824
 825        let CurrentEditPrediction {
 826            requested_by,
 827            prediction,
 828            ..
 829        } = project_state.current_prediction.as_ref()?;
 830
 831        if prediction.targets_buffer(buffer.read(cx)) {
 832            Some(BufferEditPrediction::Local { prediction })
 833        } else {
 834            let show_jump = match requested_by {
 835                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
 836                    requested_by_buffer_id == &buffer.entity_id()
 837                }
 838                PredictionRequestedBy::DiagnosticsUpdate => true,
 839            };
 840
 841            if show_jump {
 842                Some(BufferEditPrediction::Jump { prediction })
 843            } else {
 844                None
 845            }
 846        }
 847    }
 848
    /// Accepts the current prediction for `project`: removes it from the
    /// project state, cancels all still-pending prediction requests, and
    /// reports the acceptance to the backend in a background task.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        // Acceptance reporting only exists for the Zeta models; Sweep has no
        // such endpoint, so bail out early for it.
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Take the prediction out of the state so it can't be accepted twice.
        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Accepting supersedes any in-flight requests; cancel them all.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // The env var allows pointing the accept report at a custom
            // endpoint (e.g. for local development).
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // Surface token-refresh / version errors through the shared handler.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 901
    /// Long-running worker that drains the rejection channel and reports
    /// rejections to the backend in batches.
    ///
    /// Rejections are debounced: while fewer than half the per-request maximum
    /// have accumulated, the loop waits for either another queued rejection or
    /// the debounce timer before flushing. Failed flushes keep their items in
    /// `batched` so they are retried on the next send.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        // Peekable lets us check for an immediately-available next item
        // without consuming it, so the batch can keep growing.
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    // More items already queued: loop back and batch them too.
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    // Debounce elapsed with nothing new: flush what we have.
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            // Flush only the newest `flush_count` items; the request body has
            // a hard cap on rejection count.
            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
            )
            .await;

            // Only drop the flushed items on success; on failure they stay in
            // `batched` and are retried with the next flush.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
 959
 960    fn reject_current_prediction(
 961        &mut self,
 962        reason: EditPredictionRejectReason,
 963        project: &Entity<Project>,
 964    ) {
 965        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 966            project_state.pending_predictions.clear();
 967            if let Some(prediction) = project_state.current_prediction.take() {
 968                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
 969            }
 970        };
 971    }
 972
 973    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
 974        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 975            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
 976                if !current_prediction.was_shown {
 977                    current_prediction.was_shown = true;
 978                    self.shown_predictions
 979                        .push_front(current_prediction.prediction.clone());
 980                    if self.shown_predictions.len() > 50 {
 981                        let completion = self.shown_predictions.pop_back().unwrap();
 982                        self.rated_predictions.remove(&completion.id);
 983                    }
 984                }
 985            }
 986        }
 987    }
 988
 989    fn reject_prediction(
 990        &mut self,
 991        prediction_id: EditPredictionId,
 992        reason: EditPredictionRejectReason,
 993        was_shown: bool,
 994    ) {
 995        match self.edit_prediction_model {
 996            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
 997            ZetaEditPredictionModel::Sweep => return,
 998        }
 999
1000        self.reject_predictions_tx
1001            .unbounded_send(EditPredictionRejection {
1002                request_id: prediction_id.to_string(),
1003                reason,
1004                was_shown,
1005            })
1006            .log_err();
1007    }
1008
1009    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1010        self.projects
1011            .get(&project.entity_id())
1012            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1013    }
1014
1015    pub fn refresh_prediction_from_buffer(
1016        &mut self,
1017        project: Entity<Project>,
1018        buffer: Entity<Buffer>,
1019        position: language::Anchor,
1020        cx: &mut Context<Self>,
1021    ) {
1022        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1023            let Some(request_task) = this
1024                .update(cx, |this, cx| {
1025                    this.request_prediction(&project, &buffer, position, cx)
1026                })
1027                .log_err()
1028            else {
1029                return Task::ready(anyhow::Ok(None));
1030            };
1031
1032            cx.spawn(async move |_cx| {
1033                request_task.await.map(|prediction_result| {
1034                    prediction_result.map(|prediction_result| {
1035                        (
1036                            prediction_result,
1037                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1038                        )
1039                    })
1040                })
1041            })
1042        })
1043    }
1044
    /// Queues a prediction refresh triggered by a diagnostics update, anchored
    /// at the nearest relevant diagnostic rather than the cursor.
    ///
    /// Does nothing when a prediction already exists (buffer-originated
    /// predictions take priority) or when the project is unknown.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Start from the buffer of the project's active entry, if any.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Default range/point: no cursor context, so any diagnostic in
                // the active buffer (or a nearby buffer) is a candidate.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(&project, &jump_buffer, jump_position, cx)
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        // A buffer-originated prediction may have arrived while
                        // we awaited; if so, reject ours as lower priority.
                        if this
                            .get_or_init_zeta_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1119
    /// Minimum delay between consecutive prediction refreshes triggered by the
    /// same entity; zeroed in tests so they don't have to wait out the timer.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1124
    /// Enqueues a prediction refresh produced by `do_refresh`, applying
    /// throttling (per `throttle_entity`), cancellation, and replacement
    /// policy against the project's current prediction.
    ///
    /// At most two refreshes are kept pending: a queued third replaces the
    /// most recently queued one, which gets cancelled.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        let zeta_project = self.get_or_init_zeta_project(&project, cx);
        let pending_prediction_id = zeta_project.next_pending_prediction_id;
        zeta_project.next_pending_prediction_id += 1;
        let last_request = zeta_project.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Throttle: if the previous refresh came from the same entity and
            // was within THROTTLE_TIMEOUT, wait out the remainder first.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_zeta_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let zeta_project = this.get_or_init_zeta_project(&project, cx);

                // The refresh may have been cancelled while it was in flight.
                let is_cancelled = zeta_project
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                            };

                            // If a prediction is already displayed, keep
                            // whichever one the replacement policy prefers and
                            // report the other as rejected.
                            if let Some(current_prediction) =
                                zeta_project.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        // The request itself decided the prediction should be
                        // rejected (e.g. superseded while in flight).
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false);
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-borrow: reject_* above needed `this` mutably.
                let zeta_project = this.get_or_init_zeta_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    zeta_project.current_prediction = Some(new_prediction);
                }

                // Drop this entry from the pending list and cancel anything
                // queued before it (it is now stale).
                let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            zeta_project.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_zeta_project(&project, cx)
                    .pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Keep at most two pending refreshes; a third replaces (and cancels)
        // the most recently queued one.
        if zeta_project.pending_predictions.len() <= 1 {
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if zeta_project.pending_predictions.len() == 2 {
            let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            zeta_project.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1265
1266    pub fn request_prediction(
1267        &mut self,
1268        project: &Entity<Project>,
1269        active_buffer: &Entity<Buffer>,
1270        position: language::Anchor,
1271        cx: &mut Context<Self>,
1272    ) -> Task<Result<Option<EditPredictionResult>>> {
1273        self.request_prediction_internal(
1274            project.clone(),
1275            active_buffer.clone(),
1276            position,
1277            cx.has_flag::<Zeta2FeatureFlag>(),
1278            cx,
1279        )
1280    }
1281
    /// Dispatches a prediction request to the configured model, optionally
    /// retrying once at the nearest out-of-range diagnostic (`allow_jump`)
    /// when the first request yields no prediction.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Diagnostics within this many rows of the cursor are considered part
        // of the original request; jumps only target diagnostics outside it.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            // No prediction here: optionally retry once at the closest
            // diagnostic outside the searched range. `allow_jump` is false on
            // the retry itself, so this recurses at most one level.
            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1369
    /// Finds the next diagnostic worth jumping to: first the closest one in
    /// the active buffer outside the given search range, and failing that, a
    /// diagnostic in the project buffer whose path shares the most leading
    /// directory components with the active buffer's path.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Diagnostics overlapping the search range were already part
                // of the previous request; skip them.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Nothing in the active buffer: look for the most closely-related
        // other buffer that has diagnostics.
        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            // min_by_key picks the numerically smallest
                            // severity — presumably the most severe (LSP
                            // orders Error < Warning); TODO confirm.
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1448
1449    fn request_prediction_with_zeta2(
1450        &mut self,
1451        project: &Entity<Project>,
1452        active_buffer: &Entity<Buffer>,
1453        active_snapshot: BufferSnapshot,
1454        position: language::Anchor,
1455        events: Vec<Arc<Event>>,
1456        cx: &mut Context<Self>,
1457    ) -> Task<Result<Option<EditPredictionResult>>> {
1458        let project_state = self.projects.get(&project.entity_id());
1459
1460        let index_state = project_state.and_then(|state| {
1461            state
1462                .syntax_index
1463                .as_ref()
1464                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
1465        });
1466        let options = self.options.clone();
1467        let buffer_snapshotted_at = Instant::now();
1468        let Some(excerpt_path) = active_snapshot
1469            .file()
1470            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
1471        else {
1472            return Task::ready(Err(anyhow!("No file path for excerpt")));
1473        };
1474        let client = self.client.clone();
1475        let llm_token = self.llm_token.clone();
1476        let app_version = AppVersion::global(cx);
1477        let worktree_snapshots = project
1478            .read(cx)
1479            .worktrees(cx)
1480            .map(|worktree| worktree.read(cx).snapshot())
1481            .collect::<Vec<_>>();
1482        let debug_tx = self.debug_tx.clone();
1483
1484        let diagnostics = active_snapshot.diagnostic_sets().clone();
1485
1486        let file = active_buffer.read(cx).file();
1487        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
1488            let mut path = f.worktree.read(cx).absolutize(&f.path);
1489            if path.pop() { Some(path) } else { None }
1490        });
1491
1492        // TODO data collection
1493        let can_collect_data = file
1494            .as_ref()
1495            .map_or(false, |file| self.can_collect_file(project, file, cx));
1496
1497        let empty_context_files = HashMap::default();
1498        let context_files = project_state
1499            .and_then(|project_state| project_state.context.as_ref())
1500            .unwrap_or(&empty_context_files);
1501
1502        #[cfg(feature = "eval-support")]
1503        let parsed_fut = futures::future::join_all(
1504            context_files
1505                .keys()
1506                .map(|buffer| buffer.read(cx).parsing_idle()),
1507        );
1508
1509        let mut included_files = context_files
1510            .iter()
1511            .filter_map(|(buffer_entity, ranges)| {
1512                let buffer = buffer_entity.read(cx);
1513                Some((
1514                    buffer_entity.clone(),
1515                    buffer.snapshot(),
1516                    buffer.file()?.full_path(cx).into(),
1517                    ranges.clone(),
1518                ))
1519            })
1520            .collect::<Vec<_>>();
1521
1522        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
1523            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
1524        });
1525
1526        #[cfg(feature = "eval-support")]
1527        let eval_cache = self.eval_cache.clone();
1528
1529        let request_task = cx.background_spawn({
1530            let active_buffer = active_buffer.clone();
1531            async move {
1532                #[cfg(feature = "eval-support")]
1533                parsed_fut.await;
1534
1535                let index_state = if let Some(index_state) = index_state {
1536                    Some(index_state.lock_owned().await)
1537                } else {
1538                    None
1539                };
1540
1541                let cursor_offset = position.to_offset(&active_snapshot);
1542                let cursor_point = cursor_offset.to_point(&active_snapshot);
1543
1544                let before_retrieval = Instant::now();
1545
1546                let (diagnostic_groups, diagnostic_groups_truncated) =
1547                    Self::gather_nearby_diagnostics(
1548                        cursor_offset,
1549                        &diagnostics,
1550                        &active_snapshot,
1551                        options.max_diagnostic_bytes,
1552                    );
1553
1554                let cloud_request = match options.context {
1555                    ContextMode::Agentic(context_options) => {
1556                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
1557                            cursor_point,
1558                            &active_snapshot,
1559                            &context_options.excerpt,
1560                            index_state.as_deref(),
1561                        ) else {
1562                            return Ok((None, None));
1563                        };
1564
1565                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
1566                            ..active_snapshot.anchor_before(excerpt.range.end);
1567
1568                        if let Some(buffer_ix) =
1569                            included_files.iter().position(|(_, snapshot, _, _)| {
1570                                snapshot.remote_id() == active_snapshot.remote_id()
1571                            })
1572                        {
1573                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
1574                            ranges.push(excerpt_anchor_range);
1575                            retrieval_search::merge_anchor_ranges(ranges, buffer);
1576                            let last_ix = included_files.len() - 1;
1577                            included_files.swap(buffer_ix, last_ix);
1578                        } else {
1579                            included_files.push((
1580                                active_buffer.clone(),
1581                                active_snapshot.clone(),
1582                                excerpt_path.clone(),
1583                                vec![excerpt_anchor_range],
1584                            ));
1585                        }
1586
1587                        let included_files = included_files
1588                            .iter()
1589                            .map(|(_, snapshot, path, ranges)| {
1590                                let ranges = ranges
1591                                    .iter()
1592                                    .map(|range| {
1593                                        let point_range = range.to_point(&snapshot);
1594                                        Line(point_range.start.row)..Line(point_range.end.row)
1595                                    })
1596                                    .collect::<Vec<_>>();
1597                                let excerpts = assemble_excerpts(&snapshot, ranges);
1598                                predict_edits_v3::IncludedFile {
1599                                    path: path.clone(),
1600                                    max_row: Line(snapshot.max_point().row),
1601                                    excerpts,
1602                                }
1603                            })
1604                            .collect::<Vec<_>>();
1605
1606                        predict_edits_v3::PredictEditsRequest {
1607                            excerpt_path,
1608                            excerpt: String::new(),
1609                            excerpt_line_range: Line(0)..Line(0),
1610                            excerpt_range: 0..0,
1611                            cursor_point: predict_edits_v3::Point {
1612                                line: predict_edits_v3::Line(cursor_point.row),
1613                                column: cursor_point.column,
1614                            },
1615                            included_files,
1616                            referenced_declarations: vec![],
1617                            events,
1618                            can_collect_data,
1619                            diagnostic_groups,
1620                            diagnostic_groups_truncated,
1621                            debug_info: debug_tx.is_some(),
1622                            prompt_max_bytes: Some(options.max_prompt_bytes),
1623                            prompt_format: options.prompt_format,
1624                            // TODO [zeta2]
1625                            signatures: vec![],
1626                            excerpt_parent: None,
1627                            git_info: None,
1628                        }
1629                    }
1630                    ContextMode::Syntax(context_options) => {
1631                        let Some(context) = EditPredictionContext::gather_context(
1632                            cursor_point,
1633                            &active_snapshot,
1634                            parent_abs_path.as_deref(),
1635                            &context_options,
1636                            index_state.as_deref(),
1637                        ) else {
1638                            return Ok((None, None));
1639                        };
1640
1641                        make_syntax_context_cloud_request(
1642                            excerpt_path,
1643                            context,
1644                            events,
1645                            can_collect_data,
1646                            diagnostic_groups,
1647                            diagnostic_groups_truncated,
1648                            None,
1649                            debug_tx.is_some(),
1650                            &worktree_snapshots,
1651                            index_state.as_deref(),
1652                            Some(options.max_prompt_bytes),
1653                            options.prompt_format,
1654                        )
1655                    }
1656                };
1657
1658                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
1659
1660                let inputs = EditPredictionInputs {
1661                    included_files: cloud_request.included_files,
1662                    events: cloud_request.events,
1663                    cursor_point: cloud_request.cursor_point,
1664                    cursor_path: cloud_request.excerpt_path,
1665                };
1666
1667                let retrieval_time = Instant::now() - before_retrieval;
1668
1669                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
1670                    let (response_tx, response_rx) = oneshot::channel();
1671
1672                    debug_tx
1673                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
1674                            ZetaEditPredictionDebugInfo {
1675                                inputs: inputs.clone(),
1676                                retrieval_time,
1677                                buffer: active_buffer.downgrade(),
1678                                local_prompt: match prompt_result.as_ref() {
1679                                    Ok((prompt, _)) => Ok(prompt.clone()),
1680                                    Err(err) => Err(err.to_string()),
1681                                },
1682                                position,
1683                                response_rx,
1684                            },
1685                        ))
1686                        .ok();
1687                    Some(response_tx)
1688                } else {
1689                    None
1690                };
1691
1692                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
1693                    if let Some(debug_response_tx) = debug_response_tx {
1694                        debug_response_tx
1695                            .send((Err("Request skipped".to_string()), Duration::ZERO))
1696                            .ok();
1697                    }
1698                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
1699                }
1700
1701                let (prompt, _) = prompt_result?;
1702                let generation_params =
1703                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
1704                let request = open_ai::Request {
1705                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1706                    messages: vec![open_ai::RequestMessage::User {
1707                        content: open_ai::MessageContent::Plain(prompt),
1708                    }],
1709                    stream: false,
1710                    max_completion_tokens: None,
1711                    stop: generation_params.stop.unwrap_or_default(),
1712                    temperature: generation_params.temperature.unwrap_or(0.7),
1713                    tool_choice: None,
1714                    parallel_tool_calls: None,
1715                    tools: vec![],
1716                    prompt_cache_key: None,
1717                    reasoning_effort: None,
1718                };
1719
1720                log::trace!("Sending edit prediction request");
1721
1722                let before_request = Instant::now();
1723                let response = Self::send_raw_llm_request(
1724                    request,
1725                    client,
1726                    llm_token,
1727                    app_version,
1728                    #[cfg(feature = "eval-support")]
1729                    eval_cache,
1730                    #[cfg(feature = "eval-support")]
1731                    EvalCacheEntryKind::Prediction,
1732                )
1733                .await;
1734                let received_response_at = Instant::now();
1735                let request_time = received_response_at - before_request;
1736
1737                log::trace!("Got edit prediction response");
1738
1739                if let Some(debug_response_tx) = debug_response_tx {
1740                    debug_response_tx
1741                        .send((
1742                            response
1743                                .as_ref()
1744                                .map_err(|err| err.to_string())
1745                                .map(|response| response.0.clone()),
1746                            request_time,
1747                        ))
1748                        .ok();
1749                }
1750
1751                let (res, usage) = response?;
1752                let request_id = EditPredictionId(res.id.clone().into());
1753                let Some(mut output_text) = text_from_response(res) else {
1754                    return Ok((Some((request_id, None)), usage));
1755                };
1756
1757                if output_text.contains(CURSOR_MARKER) {
1758                    log::trace!("Stripping out {CURSOR_MARKER} from response");
1759                    output_text = output_text.replace(CURSOR_MARKER, "");
1760                }
1761
1762                let get_buffer_from_context = |path: &Path| {
1763                    included_files
1764                        .iter()
1765                        .find_map(|(_, buffer, probe_path, ranges)| {
1766                            if probe_path.as_ref() == path {
1767                                Some((buffer, ranges.as_slice()))
1768                            } else {
1769                                None
1770                            }
1771                        })
1772                };
1773
1774                let (edited_buffer_snapshot, edits) = match options.prompt_format {
1775                    PromptFormat::NumLinesUniDiff => {
1776                        // TODO: Implement parsing of multi-file diffs
1777                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1778                    }
1779                    PromptFormat::Minimal
1780                    | PromptFormat::MinimalQwen
1781                    | PromptFormat::SeedCoder1120 => {
1782                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
1783                            let edits = vec![];
1784                            (&active_snapshot, edits)
1785                        } else {
1786                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1787                        }
1788                    }
1789                    PromptFormat::OldTextNewText => {
1790                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
1791                            .await?
1792                    }
1793                    _ => {
1794                        bail!("unsupported prompt format {}", options.prompt_format)
1795                    }
1796                };
1797
1798                let edited_buffer = included_files
1799                    .iter()
1800                    .find_map(|(buffer, snapshot, _, _)| {
1801                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1802                            Some(buffer.clone())
1803                        } else {
1804                            None
1805                        }
1806                    })
1807                    .context("Failed to find buffer in included_buffers")?;
1808
1809                anyhow::Ok((
1810                    Some((
1811                        request_id,
1812                        Some((
1813                            inputs,
1814                            edited_buffer,
1815                            edited_buffer_snapshot.clone(),
1816                            edits,
1817                            received_response_at,
1818                        )),
1819                    )),
1820                    usage,
1821                ))
1822            }
1823        });
1824
1825        cx.spawn({
1826            async move |this, cx| {
1827                let Some((id, prediction)) =
1828                    Self::handle_api_response(&this, request_task.await, cx)?
1829                else {
1830                    return Ok(None);
1831                };
1832
1833                let Some((
1834                    inputs,
1835                    edited_buffer,
1836                    edited_buffer_snapshot,
1837                    edits,
1838                    received_response_at,
1839                )) = prediction
1840                else {
1841                    return Ok(Some(EditPredictionResult {
1842                        id,
1843                        prediction: Err(EditPredictionRejectReason::Empty),
1844                    }));
1845                };
1846
1847                // TODO telemetry: duration, etc
1848                Ok(Some(
1849                    EditPredictionResult::new(
1850                        id,
1851                        &edited_buffer,
1852                        &edited_buffer_snapshot,
1853                        edits.into(),
1854                        buffer_snapshotted_at,
1855                        received_response_at,
1856                        inputs,
1857                        cx,
1858                    )
1859                    .await,
1860                ))
1861            }
1862        })
1863    }
1864
    /// Sends `request` to the raw edit-prediction LLM endpoint and returns the
    /// deserialized response together with any usage information parsed from
    /// the response headers.
    ///
    /// The target URL is the `PREDICT_EDITS_URL` override when set, otherwise
    /// the client's `/predict_edits/raw` zed LLM URL. When built with the
    /// `eval-support` feature and a cache is provided, the response is looked
    /// up by a hash of the URL plus the serialized request before hitting the
    /// network, and written back afterwards; cache hits report `None` usage.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        // Cache key hashes the URL and the pretty-printed request JSON; a hit
        // short-circuits the network request entirely.
        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Persist the fresh response so subsequent eval runs hit the cache.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1922
1923    fn handle_api_response<T>(
1924        this: &WeakEntity<Self>,
1925        response: Result<(T, Option<EditPredictionUsage>)>,
1926        cx: &mut gpui::AsyncApp,
1927    ) -> Result<T> {
1928        match response {
1929            Ok((data, usage)) => {
1930                if let Some(usage) = usage {
1931                    this.update(cx, |this, cx| {
1932                        this.user_store.update(cx, |user_store, cx| {
1933                            user_store.update_edit_prediction_usage(usage, cx);
1934                        });
1935                    })
1936                    .ok();
1937                }
1938                Ok(data)
1939            }
1940            Err(err) => {
1941                if err.is::<ZedUpdateRequiredError>() {
1942                    cx.update(|cx| {
1943                        this.update(cx, |this, _cx| {
1944                            this.update_required = true;
1945                        })
1946                        .ok();
1947
1948                        let error_message: SharedString = err.to_string().into();
1949                        show_app_notification(
1950                            NotificationId::unique::<ZedUpdateRequiredError>(),
1951                            cx,
1952                            move |cx| {
1953                                cx.new(|cx| {
1954                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1955                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1956                                })
1957                            },
1958                        );
1959                    })
1960                    .ok();
1961                }
1962                Err(err)
1963            }
1964        }
1965    }
1966
    /// Issues a POST request produced by `build` — with JSON content type,
    /// bearer `llm_token` authorization, and the Zed version header — and
    /// deserializes the JSON response body into `Res`.
    ///
    /// If the response carries a minimum-required-version header greater than
    /// `app_version`, fails with `ZedUpdateRequiredError`. If the server
    /// signals an expired LLM token, the token is refreshed and the request
    /// retried exactly once. Any other non-success status fails with the
    /// status and response body in the error message.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // Loops at most twice: the initial attempt plus one retry after an
        // expired-token refresh.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Version enforcement happens before the status check so that an
            // outdated client gets the dedicated error regardless of status.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2030
    /// Gap since the last edit beyond which the user counts as having been
    /// idle; editing again after this triggers an immediate context refresh.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Pause to wait for after an edit before refreshing context in the
    /// non-idle (debounced) case.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
2033
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
2036    fn refresh_context_if_needed(
2037        &mut self,
2038        project: &Entity<Project>,
2039        buffer: &Entity<language::Buffer>,
2040        cursor_position: language::Anchor,
2041        cx: &mut Context<Self>,
2042    ) {
2043        if !matches!(self.edit_prediction_model, ZetaEditPredictionModel::Zeta2) {
2044            return;
2045        }
2046
2047        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
2048            return;
2049        }
2050
2051        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
2052            return;
2053        };
2054
2055        let now = Instant::now();
2056        let was_idle = zeta_project
2057            .refresh_context_timestamp
2058            .map_or(true, |timestamp| {
2059                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
2060            });
2061        zeta_project.refresh_context_timestamp = Some(now);
2062        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
2063            let buffer = buffer.clone();
2064            let project = project.clone();
2065            async move |this, cx| {
2066                if was_idle {
2067                    log::debug!("refetching edit prediction context after idle");
2068                } else {
2069                    cx.background_executor()
2070                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
2071                        .await;
2072                    log::debug!("refetching edit prediction context after pause");
2073                }
2074                this.update(cx, |this, cx| {
2075                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
2076
2077                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
2078                        zeta_project.refresh_context_task = Some(task.log_err());
2079                    };
2080                })
2081                .ok()
2082            }
2083        }));
2084    }
2085
2086    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
2087    // and avoid spawning more than one concurrent task.
    /// Asks the context-retrieval model to plan searches for the current
    /// cursor position, runs the resulting queries against the project, and
    /// stores the found excerpts as the project's context.
    ///
    /// Returns an immediately-ready task when the project isn't tracked, the
    /// context mode isn't agentic, or no excerpt can be selected around the
    /// cursor. Debug info is streamed over `debug_tx` at each stage when a
    /// debug listener is attached.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        // Select an excerpt window around the cursor to include in the
        // planning prompt.
        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the search-planning prompt from the cursor excerpt plus the
        // project's recent edit events.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project.events(cx),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // The search tool's JSON schema (and its description, extracted from
        // the schema) is computed once and cached for all refreshes.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        // One user message with a single search tool the model can call.
        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            // Stage 1: ask the model which searches to run.
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect search queries from the tool calls, skipping calls to
            // tools we didn't offer.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            // Stage 2: execute the planned searches against the project.
            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Stage 3: store the results and clear the in-flight task marker.
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
2297
2298    pub fn set_context(
2299        &mut self,
2300        project: Entity<Project>,
2301        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2302    ) {
2303        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2304            zeta_project.context = Some(context);
2305        }
2306    }
2307
    /// Collects resolved diagnostic groups from every language server's set,
    /// sorts them by distance from `cursor_offset`, and serializes them until
    /// `max_diagnostics_bytes` is exceeded.
    ///
    /// Returns the serialized groups plus a flag indicating whether any
    /// groups were dropped because of the byte budget.
    fn gather_nearby_diagnostics(
        cursor_offset: usize,
        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
        snapshot: &BufferSnapshot,
        max_diagnostics_bytes: usize,
    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
        // TODO: Could make this more efficient
        let mut diagnostic_groups = Vec::new();
        for (language_server_id, diagnostics) in diagnostic_sets {
            let mut groups = Vec::new();
            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
            diagnostic_groups.extend(
                groups
                    .into_iter()
                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
            );
        }

        // sort by proximity to cursor, measured from the primary entry's range
        diagnostic_groups.sort_by_key(|group| {
            let range = &group.entries[group.primary_ix].range;
            if range.start >= cursor_offset {
                range.start - cursor_offset
            } else if cursor_offset >= range.end {
                cursor_offset - range.end
            } else {
                // NOTE(review): when the cursor lies strictly inside the range
                // this is the distance to the nearer endpoint, not 0 — confirm
                // that diagnostics containing the cursor shouldn't sort first.
                (cursor_offset - range.start).min(range.end - cursor_offset)
            }
        });

        let mut results = Vec::new();
        let mut diagnostic_groups_truncated = false;
        let mut diagnostics_byte_count = 0;
        for group in diagnostic_groups {
            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
            diagnostics_byte_count += raw_value.get().len();
            // Stop at the first group that would exceed the budget; the
            // remaining (more distant) groups are dropped.
            if diagnostics_byte_count > max_diagnostics_bytes {
                diagnostic_groups_truncated = true;
                break;
            }
            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
        }

        (results, diagnostic_groups_truncated)
    }
2353
2354    // TODO: Dedupe with similar code in request_prediction?
2355    pub fn cloud_request_for_zeta_cli(
2356        &mut self,
2357        project: &Entity<Project>,
2358        buffer: &Entity<Buffer>,
2359        position: language::Anchor,
2360        cx: &mut Context<Self>,
2361    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
2362        let project_state = self.projects.get(&project.entity_id());
2363
2364        let index_state = project_state.and_then(|state| {
2365            state
2366                .syntax_index
2367                .as_ref()
2368                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
2369        });
2370        let options = self.options.clone();
2371        let snapshot = buffer.read(cx).snapshot();
2372        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
2373            return Task::ready(Err(anyhow!("No file path for excerpt")));
2374        };
2375        let worktree_snapshots = project
2376            .read(cx)
2377            .worktrees(cx)
2378            .map(|worktree| worktree.read(cx).snapshot())
2379            .collect::<Vec<_>>();
2380
2381        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
2382            let mut path = f.worktree.read(cx).absolutize(&f.path);
2383            if path.pop() { Some(path) } else { None }
2384        });
2385
2386        cx.background_spawn(async move {
2387            let index_state = if let Some(index_state) = index_state {
2388                Some(index_state.lock_owned().await)
2389            } else {
2390                None
2391            };
2392
2393            let cursor_point = position.to_point(&snapshot);
2394
2395            let debug_info = true;
2396            EditPredictionContext::gather_context(
2397                cursor_point,
2398                &snapshot,
2399                parent_abs_path.as_deref(),
2400                match &options.context {
2401                    ContextMode::Agentic(_) => {
2402                        // TODO
2403                        panic!("Llm mode not supported in zeta cli yet");
2404                    }
2405                    ContextMode::Syntax(edit_prediction_context_options) => {
2406                        edit_prediction_context_options
2407                    }
2408                },
2409                index_state.as_deref(),
2410            )
2411            .context("Failed to select excerpt")
2412            .map(|context| {
2413                make_syntax_context_cloud_request(
2414                    excerpt_path.into(),
2415                    context,
2416                    // TODO pass everything
2417                    Vec::new(),
2418                    false,
2419                    Vec::new(),
2420                    false,
2421                    None,
2422                    debug_info,
2423                    &worktree_snapshots,
2424                    index_state.as_deref(),
2425                    Some(options.max_prompt_bytes),
2426                    options.prompt_format,
2427                )
2428            })
2429        })
2430    }
2431
2432    pub fn wait_for_initial_indexing(
2433        &mut self,
2434        project: &Entity<Project>,
2435        cx: &mut Context<Self>,
2436    ) -> Task<Result<()>> {
2437        let zeta_project = self.get_or_init_zeta_project(project, cx);
2438        if let Some(syntax_index) = &zeta_project.syntax_index {
2439            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2440        } else {
2441            Task::ready(Ok(()))
2442        }
2443    }
2444
2445    fn is_file_open_source(
2446        &self,
2447        project: &Entity<Project>,
2448        file: &Arc<dyn File>,
2449        cx: &App,
2450    ) -> bool {
2451        if !file.is_local() || file.is_private() {
2452            return false;
2453        }
2454        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2455            return false;
2456        };
2457        zeta_project
2458            .license_detection_watchers
2459            .get(&file.worktree_id(cx))
2460            .as_ref()
2461            .is_some_and(|watcher| watcher.is_project_open_source())
2462    }
2463
2464    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2465        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2466    }
2467
2468    fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2469        if !self.data_collection_choice.is_enabled() {
2470            return false;
2471        }
2472        events.iter().all(|event| {
2473            matches!(
2474                event.as_ref(),
2475                Event::BufferChange {
2476                    in_open_source_repo: true,
2477                    ..
2478                }
2479            )
2480        })
2481    }
2482
2483    fn load_data_collection_choice() -> DataCollectionChoice {
2484        let choice = KEY_VALUE_STORE
2485            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2486            .log_err()
2487            .flatten();
2488
2489        match choice.as_deref() {
2490            Some("true") => DataCollectionChoice::Enabled,
2491            Some("false") => DataCollectionChoice::Disabled,
2492            Some(_) => {
2493                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2494                DataCollectionChoice::NotAnswered
2495            }
2496            None => DataCollectionChoice::NotAnswered,
2497        }
2498    }
2499
    /// Iterates over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2503
    /// Number of predictions that have been shown to the user.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2507
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2511
    /// Records the user's rating for `prediction` and reports it via
    /// telemetry, including the prediction's inputs, its output as a unified
    /// diff, and the free-form `feedback` text.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Push the event out immediately rather than waiting for a periodic flush.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2530}
2531
2532pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2533    let choice = res.choices.pop()?;
2534    let output_text = match choice.message {
2535        open_ai::RequestMessage::Assistant {
2536            content: Some(open_ai::MessageContent::Plain(content)),
2537            ..
2538        } => content,
2539        open_ai::RequestMessage::Assistant {
2540            content: Some(open_ai::MessageContent::Multipart(mut content)),
2541            ..
2542        } => {
2543            if content.is_empty() {
2544                log::error!("No output from Baseten completion response");
2545                return None;
2546            }
2547
2548            match content.remove(0) {
2549                open_ai::MessagePart::Text { text } => text,
2550                open_ai::MessagePart::Image { .. } => {
2551                    log::error!("Expected text, got an image");
2552                    return None;
2553                }
2554            }
2555        }
2556        _ => {
2557            log::error!("Invalid response message: {:?}", choice.message);
2558            return None;
2559        }
2560    };
2561    Some(output_text)
2562}
2563
/// Error indicating the user's Zed version is below the minimum the
/// edit-prediction server accepts.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    // Lowest Zed version the server will serve predictions to.
    minimum_version: Version,
}
2571
2572fn make_syntax_context_cloud_request(
2573    excerpt_path: Arc<Path>,
2574    context: EditPredictionContext,
2575    events: Vec<Arc<predict_edits_v3::Event>>,
2576    can_collect_data: bool,
2577    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2578    diagnostic_groups_truncated: bool,
2579    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2580    debug_info: bool,
2581    worktrees: &Vec<worktree::Snapshot>,
2582    index_state: Option<&SyntaxIndexState>,
2583    prompt_max_bytes: Option<usize>,
2584    prompt_format: PromptFormat,
2585) -> predict_edits_v3::PredictEditsRequest {
2586    let mut signatures = Vec::new();
2587    let mut declaration_to_signature_index = HashMap::default();
2588    let mut referenced_declarations = Vec::new();
2589
2590    for snippet in context.declarations {
2591        let project_entry_id = snippet.declaration.project_entry_id();
2592        let Some(path) = worktrees.iter().find_map(|worktree| {
2593            worktree.entry_for_id(project_entry_id).map(|entry| {
2594                let mut full_path = RelPathBuf::new();
2595                full_path.push(worktree.root_name());
2596                full_path.push(&entry.path);
2597                full_path
2598            })
2599        }) else {
2600            continue;
2601        };
2602
2603        let parent_index = index_state.and_then(|index_state| {
2604            snippet.declaration.parent().and_then(|parent| {
2605                add_signature(
2606                    parent,
2607                    &mut declaration_to_signature_index,
2608                    &mut signatures,
2609                    index_state,
2610                )
2611            })
2612        });
2613
2614        let (text, text_is_truncated) = snippet.declaration.item_text();
2615        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2616            path: path.as_std_path().into(),
2617            text: text.into(),
2618            range: snippet.declaration.item_line_range(),
2619            text_is_truncated,
2620            signature_range: snippet.declaration.signature_range_in_item_text(),
2621            parent_index,
2622            signature_score: snippet.score(DeclarationStyle::Signature),
2623            declaration_score: snippet.score(DeclarationStyle::Declaration),
2624            score_components: snippet.components,
2625        });
2626    }
2627
2628    let excerpt_parent = index_state.and_then(|index_state| {
2629        context
2630            .excerpt
2631            .parent_declarations
2632            .last()
2633            .and_then(|(parent, _)| {
2634                add_signature(
2635                    *parent,
2636                    &mut declaration_to_signature_index,
2637                    &mut signatures,
2638                    index_state,
2639                )
2640            })
2641    });
2642
2643    predict_edits_v3::PredictEditsRequest {
2644        excerpt_path,
2645        excerpt: context.excerpt_text.body,
2646        excerpt_line_range: context.excerpt.line_range,
2647        excerpt_range: context.excerpt.range,
2648        cursor_point: predict_edits_v3::Point {
2649            line: predict_edits_v3::Line(context.cursor_point.row),
2650            column: context.cursor_point.column,
2651        },
2652        referenced_declarations,
2653        included_files: vec![],
2654        signatures,
2655        excerpt_parent,
2656        events,
2657        can_collect_data,
2658        diagnostic_groups,
2659        diagnostic_groups_truncated,
2660        git_info,
2661        debug_info,
2662        prompt_max_bytes,
2663        prompt_format,
2664    }
2665}
2666
2667fn add_signature(
2668    declaration_id: DeclarationId,
2669    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2670    signatures: &mut Vec<Signature>,
2671    index: &SyntaxIndexState,
2672) -> Option<usize> {
2673    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2674        return Some(*signature_index);
2675    }
2676    let Some(parent_declaration) = index.declaration(declaration_id) else {
2677        log::error!("bug: missing parent declaration");
2678        return None;
2679    };
2680    let parent_index = parent_declaration.parent().and_then(|parent| {
2681        add_signature(parent, declaration_to_signature_index, signatures, index)
2682    });
2683    let (text, text_is_truncated) = parent_declaration.signature_text();
2684    let signature_index = signatures.len();
2685    signatures.push(Signature {
2686        text: text.into(),
2687        text_is_truncated,
2688        parent_index,
2689        range: parent_declaration.signature_line_range(),
2690    });
2691    declaration_to_signature_index.insert(declaration_id, signature_index);
2692    Some(signature_index)
2693}
2694
#[cfg(feature = "eval-support")]
/// Cache key for eval support: an entry kind paired with a hash value.
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

#[cfg(feature = "eval-support")]
/// Kind of cached eval artifact.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2705
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the kind's lowercase name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        write!(f, "{name}")
    }
}
2716
#[cfg(feature = "eval-support")]
/// Read/write store for caching eval results keyed by `EvalCacheKey`.
pub trait EvalCache: Send + Sync {
    /// Returns the cached value for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` for `key`, alongside the `input` it was derived from.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2722
/// The user's choice about sharing edit-prediction data.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No explicit choice has been made yet (treated as disabled).
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// Whether data collection is actively enabled.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Whether the user has made an explicit choice either way.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice; an unanswered choice toggles to enabled.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        if matches!(self, Self::Enabled) {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        if value { Self::Enabled } else { Self::Disabled }
    }
}
2763
2764struct ZedPredictUpsell;
2765
2766impl Dismissable for ZedPredictUpsell {
2767    const KEY: &'static str = "dismissed-edit-predict-upsell";
2768
2769    fn dismissed() -> bool {
2770        // To make this backwards compatible with older versions of Zed, we
2771        // check if the user has seen the previous Edit Prediction Onboarding
2772        // before, by checking the data collection choice which was written to
2773        // the database once the user clicked on "Accept and Enable"
2774        if KEY_VALUE_STORE
2775            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2776            .log_err()
2777            .is_some_and(|s| s.is_some())
2778        {
2779            return true;
2780        }
2781
2782        KEY_VALUE_STORE
2783            .read_kvp(Self::KEY)
2784            .log_err()
2785            .is_some_and(|s| s.is_some())
2786    }
2787}
2788
/// Whether the edit-prediction upsell modal should be shown (i.e. the user
/// has not dismissed it, nor completed the older onboarding flow).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2792
/// Registers zeta's workspace actions and sets up command-palette gating of
/// prediction-related actions.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // The rating UI is only reachable behind the feature flag.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding clears the configured edit-prediction provider
        // in the settings file.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2828
/// Hides zeta-related actions from the command palette according to the
/// disable-AI setting and the rate-completions feature flag.
fn feature_gate_predict_edits_actions(cx: &mut App) {
    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
    // Everything zeta registers; hidden wholesale when AI is disabled.
    let zeta_all_action_types = [
        TypeId::of::<RateCompletions>(),
        TypeId::of::<ResetOnboarding>(),
        zed_actions::OpenZedPredictOnboarding.type_id(),
        TypeId::of::<ClearHistory>(),
        TypeId::of::<ThumbsUpActivePrediction>(),
        TypeId::of::<ThumbsDownActivePrediction>(),
        TypeId::of::<NextEdit>(),
        TypeId::of::<PreviousEdit>(),
    ];

    // Hidden by default at startup.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_action_types(&rate_completion_action_types);
        filter.hide_action_types(&reset_onboarding_action_types);
        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
    });

    cx.observe_global::<SettingsStore>(move |cx| {
        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            if is_ai_disabled {
                filter.hide_action_types(&zeta_all_action_types);
            } else if has_feature_flag {
                // NOTE(review): when AI is re-enabled only the rate-completion
                // actions are re-shown here; the rest of
                // `zeta_all_action_types` stays hidden — confirm this is
                // intentional or handled elsewhere.
                filter.show_action_types(&rate_completion_action_types);
            } else {
                filter.hide_action_types(&rate_completion_action_types);
            }
        });
    })
    .detach();

    // Keep the rate-completion actions in sync with flag changes too.
    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
        if !DisableAiSettings::get_global(cx).disable_ai {
            if is_enabled {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.show_action_types(&rate_completion_action_types);
                });
            } else {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.hide_action_types(&rate_completion_action_types);
                });
            }
        }
    })
    .detach();
}
2880
2881#[cfg(test)]
2882mod tests {
2883    use std::{path::Path, sync::Arc, time::Duration};
2884
2885    use client::UserStore;
2886    use clock::FakeSystemClock;
2887    use cloud_llm_client::{
2888        EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
2889    };
2890    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2891    use futures::{
2892        AsyncReadExt, StreamExt,
2893        channel::{mpsc, oneshot},
2894    };
2895    use gpui::{
2896        Entity, TestAppContext,
2897        http_client::{FakeHttpClient, Response},
2898        prelude::*,
2899    };
2900    use indoc::indoc;
2901    use language::OffsetRangeExt as _;
2902    use open_ai::Usage;
2903    use pretty_assertions::{assert_eq, assert_matches};
2904    use project::{FakeFs, Project};
2905    use serde_json::json;
2906    use settings::SettingsStore;
2907    use util::path;
2908    use uuid::Uuid;
2909
2910    use crate::{BufferEditPrediction, EditPredictionId, REJECT_REQUEST_DEBOUNCE, Zeta};
2911
    /// Exercises `current_prediction_for_buffer` across three flows: a
    /// prediction editing the current file, a context-refresh round trip, and
    /// a prediction that jumps to a different file.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        // An edit targeting the same buffer surfaces as a `Local` prediction.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        // Respond with a tool call asking to search `root/2.txt`.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // From buffer1's perspective the prediction targets a different file,
        // so it is presented as a `Jump` to `root/2.txt`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once the target buffer is open, the same prediction is `Local`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
3058
    /// A single prediction round trip: request, model diff response, and the
    /// resulting edit anchored at the cursor position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        // The diff's only change appends " are you?" at line 1, column 3.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3122
    /// Verifies that edits made to a registered buffer are reported to the
    /// model as unified-diff events inside the request prompt.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Edit the registered buffer so an edit event is recorded.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = requests.predict.next().await.unwrap();

        // The edit above should appear in the prompt as a unified diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3196
3197    #[gpui::test]
3198    async fn test_empty_prediction(cx: &mut TestAppContext) {
3199        let (zeta, mut requests) = init_test(cx);
3200        let fs = FakeFs::new(cx.executor());
3201        fs.insert_tree(
3202            "/root",
3203            json!({
3204                "foo.md":  "Hello!\nHow\nBye\n"
3205            }),
3206        )
3207        .await;
3208        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3209
3210        let buffer = project
3211            .update(cx, |project, cx| {
3212                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3213                project.open_buffer(path, cx)
3214            })
3215            .await
3216            .unwrap();
3217        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3218        let position = snapshot.anchor_before(language::Point::new(1, 3));
3219
3220        zeta.update(cx, |zeta, cx| {
3221            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3222        });
3223
3224        const NO_OP_DIFF: &str = indoc! { r"
3225            --- a/root/foo.md
3226            +++ b/root/foo.md
3227            @@ ... @@
3228             Hello!
3229            -How
3230            +How
3231             Bye
3232        "};
3233
3234        let (_, respond_tx) = requests.predict.next().await.unwrap();
3235        let response = model_response(NO_OP_DIFF);
3236        let id = response.id.clone();
3237        respond_tx.send(response).unwrap();
3238
3239        cx.run_until_parked();
3240
3241        zeta.read_with(cx, |zeta, cx| {
3242            assert!(
3243                zeta.current_prediction_for_buffer(&buffer, &project, cx)
3244                    .is_none()
3245            );
3246        });
3247
3248        // prediction is reported as rejected
3249        let (reject_request, _) = requests.reject.next().await.unwrap();
3250
3251        assert_eq!(
3252            &reject_request.rejections,
3253            &[EditPredictionRejection {
3254                request_id: id,
3255                reason: EditPredictionRejectReason::Empty,
3256                was_shown: false
3257            }]
3258        );
3259    }
3260
3261    #[gpui::test]
3262    async fn test_interpolated_empty(cx: &mut TestAppContext) {
3263        let (zeta, mut requests) = init_test(cx);
3264        let fs = FakeFs::new(cx.executor());
3265        fs.insert_tree(
3266            "/root",
3267            json!({
3268                "foo.md":  "Hello!\nHow\nBye\n"
3269            }),
3270        )
3271        .await;
3272        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3273
3274        let buffer = project
3275            .update(cx, |project, cx| {
3276                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3277                project.open_buffer(path, cx)
3278            })
3279            .await
3280            .unwrap();
3281        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3282        let position = snapshot.anchor_before(language::Point::new(1, 3));
3283
3284        zeta.update(cx, |zeta, cx| {
3285            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3286        });
3287
3288        let (_, respond_tx) = requests.predict.next().await.unwrap();
3289
3290        buffer.update(cx, |buffer, cx| {
3291            buffer.set_text("Hello!\nHow are you?\nBye", cx);
3292        });
3293
3294        let response = model_response(SIMPLE_DIFF);
3295        let id = response.id.clone();
3296        respond_tx.send(response).unwrap();
3297
3298        cx.run_until_parked();
3299
3300        zeta.read_with(cx, |zeta, cx| {
3301            assert!(
3302                zeta.current_prediction_for_buffer(&buffer, &project, cx)
3303                    .is_none()
3304            );
3305        });
3306
3307        // prediction is reported as rejected
3308        let (reject_request, _) = requests.reject.next().await.unwrap();
3309
3310        assert_eq!(
3311            &reject_request.rejections,
3312            &[EditPredictionRejection {
3313                request_id: id,
3314                reason: EditPredictionRejectReason::InterpolatedEmpty,
3315                was_shown: false
3316            }]
3317        );
3318    }
3319
    /// Canonical model response shared by several tests below: a unified diff
    /// that rewrites the middle line of `root/foo.md` ("How" -> "How are you?").
    const SIMPLE_DIFF: &str = indoc! { r"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ ... @@
         Hello!
        -How
        +How are you?
         Bye
    "};
3329
3330    #[gpui::test]
3331    async fn test_replace_current(cx: &mut TestAppContext) {
3332        let (zeta, mut requests) = init_test(cx);
3333        let fs = FakeFs::new(cx.executor());
3334        fs.insert_tree(
3335            "/root",
3336            json!({
3337                "foo.md":  "Hello!\nHow\nBye\n"
3338            }),
3339        )
3340        .await;
3341        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3342
3343        let buffer = project
3344            .update(cx, |project, cx| {
3345                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3346                project.open_buffer(path, cx)
3347            })
3348            .await
3349            .unwrap();
3350        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3351        let position = snapshot.anchor_before(language::Point::new(1, 3));
3352
3353        zeta.update(cx, |zeta, cx| {
3354            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3355        });
3356
3357        let (_, respond_tx) = requests.predict.next().await.unwrap();
3358        let first_response = model_response(SIMPLE_DIFF);
3359        let first_id = first_response.id.clone();
3360        respond_tx.send(first_response).unwrap();
3361
3362        cx.run_until_parked();
3363
3364        zeta.read_with(cx, |zeta, cx| {
3365            assert_eq!(
3366                zeta.current_prediction_for_buffer(&buffer, &project, cx)
3367                    .unwrap()
3368                    .id
3369                    .0,
3370                first_id
3371            );
3372        });
3373
3374        // a second request is triggered
3375        zeta.update(cx, |zeta, cx| {
3376            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3377        });
3378
3379        let (_, respond_tx) = requests.predict.next().await.unwrap();
3380        let second_response = model_response(SIMPLE_DIFF);
3381        let second_id = second_response.id.clone();
3382        respond_tx.send(second_response).unwrap();
3383
3384        cx.run_until_parked();
3385
3386        zeta.read_with(cx, |zeta, cx| {
3387            // second replaces first
3388            assert_eq!(
3389                zeta.current_prediction_for_buffer(&buffer, &project, cx)
3390                    .unwrap()
3391                    .id
3392                    .0,
3393                second_id
3394            );
3395        });
3396
3397        // first is reported as replaced
3398        let (reject_request, _) = requests.reject.next().await.unwrap();
3399
3400        assert_eq!(
3401            &reject_request.rejections,
3402            &[EditPredictionRejection {
3403                request_id: first_id,
3404                reason: EditPredictionRejectReason::Replaced,
3405                was_shown: false
3406            }]
3407        );
3408    }
3409
3410    #[gpui::test]
3411    async fn test_current_preferred(cx: &mut TestAppContext) {
3412        let (zeta, mut requests) = init_test(cx);
3413        let fs = FakeFs::new(cx.executor());
3414        fs.insert_tree(
3415            "/root",
3416            json!({
3417                "foo.md":  "Hello!\nHow\nBye\n"
3418            }),
3419        )
3420        .await;
3421        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3422
3423        let buffer = project
3424            .update(cx, |project, cx| {
3425                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3426                project.open_buffer(path, cx)
3427            })
3428            .await
3429            .unwrap();
3430        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3431        let position = snapshot.anchor_before(language::Point::new(1, 3));
3432
3433        zeta.update(cx, |zeta, cx| {
3434            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3435        });
3436
3437        let (_, respond_tx) = requests.predict.next().await.unwrap();
3438        let first_response = model_response(SIMPLE_DIFF);
3439        let first_id = first_response.id.clone();
3440        respond_tx.send(first_response).unwrap();
3441
3442        cx.run_until_parked();
3443
3444        zeta.read_with(cx, |zeta, cx| {
3445            assert_eq!(
3446                zeta.current_prediction_for_buffer(&buffer, &project, cx)
3447                    .unwrap()
3448                    .id
3449                    .0,
3450                first_id
3451            );
3452        });
3453
3454        // a second request is triggered
3455        zeta.update(cx, |zeta, cx| {
3456            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3457        });
3458
3459        let (_, respond_tx) = requests.predict.next().await.unwrap();
3460        // worse than current prediction
3461        let second_response = model_response(indoc! { r"
3462            --- a/root/foo.md
3463            +++ b/root/foo.md
3464            @@ ... @@
3465             Hello!
3466            -How
3467            +How are
3468             Bye
3469        "});
3470        let second_id = second_response.id.clone();
3471        respond_tx.send(second_response).unwrap();
3472
3473        cx.run_until_parked();
3474
3475        zeta.read_with(cx, |zeta, cx| {
3476            // first is preferred over second
3477            assert_eq!(
3478                zeta.current_prediction_for_buffer(&buffer, &project, cx)
3479                    .unwrap()
3480                    .id
3481                    .0,
3482                first_id
3483            );
3484        });
3485
3486        // second is reported as rejected
3487        let (reject_request, _) = requests.reject.next().await.unwrap();
3488
3489        assert_eq!(
3490            &reject_request.rejections,
3491            &[EditPredictionRejection {
3492                request_id: second_id,
3493                reason: EditPredictionRejectReason::CurrentPreferred,
3494                was_shown: false
3495            }]
3496        );
3497    }
3498
    /// Two refreshes are started back-to-back; the later request's response
    /// arrives first and becomes current, so the earlier in-flight request is
    /// treated as cancelled: its late response must not replace the current
    /// prediction, and it is reported to the server as `Canceled`.
    #[gpui::test]
    async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        // start two refresh tasks
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_first) = requests.predict.next().await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_second) = requests.predict.next().await.unwrap();

        // wait for throttle
        cx.run_until_parked();

        // second responds first
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_second.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // Now let the stale first request respond.
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still second, since first was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            }]
        );
    }
3587
    /// With two requests already in flight, starting a third cancels the
    /// second (the cancelled-predictions set records its index). The first
    /// request's response still becomes current, the cancelled second's
    /// response is ignored, and the third's response replaces the first.
    /// The reject report then contains both the cancellation and the
    /// replacement.
    #[gpui::test]
    async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        // start two refresh tasks
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_first) = requests.predict.next().await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_second) = requests.predict.next().await.unwrap();

        // wait for throttle, so requests are sent
        cx.run_until_parked();

        zeta.update(cx, |zeta, cx| {
            // start a third request
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

            // 2 are pending, so 2nd is cancelled
            assert_eq!(
                zeta.get_or_init_zeta_project(&project, cx)
                    .cancelled_predictions
                    .iter()
                    .copied()
                    .collect::<Vec<_>>(),
                [1]
            );
        });

        // wait for throttle
        cx.run_until_parked();

        let (_, respond_third) = requests.predict.next().await.unwrap();

        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // The cancelled second request responds; its prediction must be dropped.
        let cancelled_response = model_response(SIMPLE_DIFF);
        let cancelled_id = cancelled_response.id.clone();
        respond_second.send(cancelled_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still first, since second was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        let third_response = model_response(SIMPLE_DIFF);
        let third_response_id = third_response.id.clone();
        respond_third.send(third_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // third completes and replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                third_response_id
            );
        });

        // second is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[
                EditPredictionRejection {
                    request_id: cancelled_id,
                    reason: EditPredictionRejectReason::Canceled,
                    was_shown: false
                },
                EditPredictionRejection {
                    request_id: first_id,
                    reason: EditPredictionRejectReason::Replaced,
                    was_shown: false
                }
            ]
        );
    }
3719
    /// Exercises the rejection-reporting pipeline: rejections are batched and
    /// debounced; hitting the batch-size limit flushes half the batch
    /// immediately; and a failed reject request is retried together with
    /// subsequently queued rejections.
    #[gpui::test]
    async fn test_rejections_flushing(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);

        zeta.update(cx, |zeta, _cx| {
            zeta.reject_prediction(
                EditPredictionId("test-1".into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
            zeta.reject_prediction(
                EditPredictionId("test-2".into()),
                EditPredictionRejectReason::Canceled,
                true,
            );
        });

        cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
        cx.run_until_parked();

        let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
        respond_tx.send(()).unwrap();

        // batched
        assert_eq!(reject_request.rejections.len(), 2);
        assert_eq!(
            reject_request.rejections[0],
            EditPredictionRejection {
                request_id: "test-1".to_string(),
                reason: EditPredictionRejectReason::Discarded,
                was_shown: false
            }
        );
        assert_eq!(
            reject_request.rejections[1],
            EditPredictionRejection {
                request_id: "test-2".to_string(),
                reason: EditPredictionRejectReason::Canceled,
                was_shown: true
            }
        );

        // Reaching batch size limit sends without debounce
        zeta.update(cx, |zeta, _cx| {
            for i in 0..70 {
                zeta.reject_prediction(
                    EditPredictionId(format!("batch-{}", i).into()),
                    EditPredictionRejectReason::Discarded,
                    false,
                );
            }
        });

        // First MAX/2 items are sent immediately
        cx.run_until_parked();
        let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
        respond_tx.send(()).unwrap();

        assert_eq!(reject_request.rejections.len(), 50);
        assert_eq!(reject_request.rejections[0].request_id, "batch-0");
        assert_eq!(reject_request.rejections[49].request_id, "batch-49");

        // Remaining items are debounced with the next batch
        // (15s is assumed to exceed REJECT_REQUEST_DEBOUNCE — confirm if that
        // constant changes)
        cx.executor().advance_clock(Duration::from_secs(15));
        cx.run_until_parked();

        let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
        respond_tx.send(()).unwrap();

        assert_eq!(reject_request.rejections.len(), 20);
        assert_eq!(reject_request.rejections[0].request_id, "batch-50");
        assert_eq!(reject_request.rejections[19].request_id, "batch-69");

        // Request failure
        zeta.update(cx, |zeta, _cx| {
            zeta.reject_prediction(
                EditPredictionId("retry-1".into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
        });

        cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
        cx.run_until_parked();

        let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
        assert_eq!(reject_request.rejections.len(), 1);
        assert_eq!(reject_request.rejections[0].request_id, "retry-1");
        // Simulate failure (dropping the responder makes the fake server's
        // oneshot resolve with an error)
        drop(_respond_tx);

        // Add another rejection
        zeta.update(cx, |zeta, _cx| {
            zeta.reject_prediction(
                EditPredictionId("retry-2".into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
        });

        cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
        cx.run_until_parked();

        // Retry should include both the failed item and the new one
        let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
        respond_tx.send(()).unwrap();

        assert_eq!(reject_request.rejections.len(), 2);
        assert_eq!(reject_request.rejections[0].request_id, "retry-1");
        assert_eq!(reject_request.rejections[1].request_id, "retry-2");
    }
3831
3832    // Skipped until we start including diagnostics in prompt
3833    // #[gpui::test]
3834    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3835    //     let (zeta, mut req_rx) = init_test(cx);
3836    //     let fs = FakeFs::new(cx.executor());
3837    //     fs.insert_tree(
3838    //         "/root",
3839    //         json!({
3840    //             "foo.md": "Hello!\nBye"
3841    //         }),
3842    //     )
3843    //     .await;
3844    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3845
3846    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3847    //     let diagnostic = lsp::Diagnostic {
3848    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3849    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
3850    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3851    //         ..Default::default()
3852    //     };
3853
3854    //     project.update(cx, |project, cx| {
3855    //         project.lsp_store().update(cx, |lsp_store, cx| {
3856    //             // Create some diagnostics
3857    //             lsp_store
3858    //                 .update_diagnostics(
3859    //                     LanguageServerId(0),
3860    //                     lsp::PublishDiagnosticsParams {
3861    //                         uri: path_to_buffer_uri.clone(),
3862    //                         diagnostics: vec![diagnostic],
3863    //                         version: None,
3864    //                     },
3865    //                     None,
3866    //                     language::DiagnosticSourceKind::Pushed,
3867    //                     &[],
3868    //                     cx,
3869    //                 )
3870    //                 .unwrap();
3871    //         });
3872    //     });
3873
3874    //     let buffer = project
3875    //         .update(cx, |project, cx| {
3876    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3877    //             project.open_buffer(path, cx)
3878    //         })
3879    //         .await
3880    //         .unwrap();
3881
3882    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3883    //     let position = snapshot.anchor_before(language::Point::new(0, 0));
3884
3885    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
3886    //         zeta.request_prediction(&project, &buffer, position, cx)
3887    //     });
3888
3889    //     let (request, _respond_tx) = req_rx.next().await.unwrap();
3890
3891    //     assert_eq!(request.diagnostic_groups.len(), 1);
3892    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3893    //         .unwrap();
3894    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3895    //     assert_eq!(
3896    //         value,
3897    //         json!({
3898    //             "entries": [{
3899    //                 "range": {
3900    //                     "start": 8,
3901    //                     "end": 10
3902    //                 },
3903    //                 "diagnostic": {
3904    //                     "source": null,
3905    //                     "code": null,
3906    //                     "code_description": null,
3907    //                     "severity": 1,
3908    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3909    //                     "markdown": null,
3910    //                     "group_id": 0,
3911    //                     "is_primary": true,
3912    //                     "is_disk_based": false,
3913    //                     "is_unnecessary": false,
3914    //                     "source_kind": "Pushed",
3915    //                     "data": null,
3916    //                     "underline": true
3917    //                 }
3918    //             }],
3919    //             "primary_ix": 0
3920    //         })
3921    //     );
3922    // }
3923
3924    fn model_response(text: &str) -> open_ai::Response {
3925        open_ai::Response {
3926            id: Uuid::new_v4().to_string(),
3927            object: "response".into(),
3928            created: 0,
3929            model: "model".into(),
3930            choices: vec![open_ai::Choice {
3931                index: 0,
3932                message: open_ai::RequestMessage::Assistant {
3933                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
3934                    tool_calls: vec![],
3935                },
3936                finish_reason: None,
3937            }],
3938            usage: Usage {
3939                prompt_tokens: 0,
3940                completion_tokens: 0,
3941                total_tokens: 0,
3942            },
3943        }
3944    }
3945
3946    fn prompt_from_request(request: &open_ai::Request) -> &str {
3947        assert_eq!(request.messages.len(), 1);
3948        let open_ai::RequestMessage::User {
3949            content: open_ai::MessageContent::Plain(content),
3950            ..
3951        } = &request.messages[0]
3952        else {
3953            panic!(
3954                "Request does not have single user message of type Plain. {:#?}",
3955                request
3956            );
3957        };
3958        content
3959    }
3960
    /// Receiving ends of the fake-server channels wired up in `init_test`.
    /// Each item pairs a decoded request body with a oneshot sender that the
    /// test uses to deliver the response for that particular request.
    struct RequestChannels {
        predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
        reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
    }
3965
    /// Sets up a `Zeta` instance backed by a fake HTTP client.
    ///
    /// The fake server answers the llm-token endpoint directly and forwards
    /// predict/reject requests over the returned `RequestChannels`, letting
    /// each test inspect the request and control when and how it responds.
    fn init_test(cx: &mut TestAppContext) -> (Entity<Zeta>, RequestChannels) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
            let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let predict_req_tx = predict_req_tx.clone();
                    let reject_req_tx = reject_req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh: answer inline with a fixed token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction requests: hand the decoded body to the
                            // test and await its response via the oneshot.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            // Rejection reports: same forwarding scheme.
                            "/predict_edits/reject" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (
                zeta,
                RequestChannels {
                    predict: predict_req_rx,
                    reject: reject_req_rx,
                },
            )
        })
    }
4032}