zeta.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
   7    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
   8    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBody,
   9    ZED_VERSION_HEADER_NAME,
  10};
  11use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  12use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  13use collections::{HashMap, HashSet};
  14use command_palette_hooks::CommandPaletteFilter;
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{
  17    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  18    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  19    SyntaxIndex, SyntaxIndexState,
  20};
  21use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
  22use futures::channel::{mpsc, oneshot};
  23use futures::{AsyncReadExt as _, StreamExt as _};
  24use gpui::{
  25    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  26    http_client::{self, AsyncBody, Method},
  27    prelude::*,
  28};
  29use language::{
  30    Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
  31};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, RefreshLlmTokenListener};
  34use open_ai::FunctionDefinition;
  35use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  36use release_channel::AppVersion;
  37use semver::Version;
  38use serde::de::DeserializeOwned;
  39use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
  40use std::any::{Any as _, TypeId};
  41use std::collections::{VecDeque, hash_map};
  42use telemetry_events::EditPredictionRating;
  43use workspace::Workspace;
  44
  45use std::ops::Range;
  46use std::path::Path;
  47use std::rc::Rc;
  48use std::str::FromStr as _;
  49use std::sync::{Arc, LazyLock};
  50use std::time::{Duration, Instant};
  51use std::{env, mem};
  52use thiserror::Error;
  53use util::rel_path::RelPathBuf;
  54use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
  55use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  56
  57pub mod assemble_excerpts;
  58mod license_detection;
  59mod onboarding_modal;
  60mod prediction;
  61mod provider;
  62mod rate_prediction_modal;
  63pub mod retrieval_search;
  64mod sweep_ai;
  65pub mod udiff;
  66mod xml_edits;
  67pub mod zeta1;
  68
  69#[cfg(test)]
  70mod zeta_tests;
  71
  72use crate::assemble_excerpts::assemble_excerpts;
  73use crate::license_detection::LicenseDetectionWatcher;
  74use crate::onboarding_modal::ZedPredictModal;
  75pub use crate::prediction::EditPrediction;
  76pub use crate::prediction::EditPredictionId;
  77pub use crate::prediction::EditPredictionInputs;
  78use crate::prediction::EditPredictionResult;
  79use crate::rate_prediction_modal::{
  80    NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
  81    ThumbsUpActivePrediction,
  82};
  83use crate::sweep_ai::SweepAi;
  84use crate::zeta1::request_prediction_with_zeta1;
  85pub use provider::ZetaEditPredictionProvider;
  86
// Workspace actions contributed by the edit prediction crate.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  98
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Edits within this many rows of the previous edit are coalesced into a
/// single change event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key-value store key under which the user's data-collection choice is persisted.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 103
/// Feature flag gating the Sweep AI edit prediction backend.
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}
/// Default sizing of the cursor excerpt included in prediction requests.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim for the cursor to sit roughly halfway through the excerpt.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context-gathering strategy: agentic retrieval.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Default options for [`ContextMode::Agentic`].
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Default options for [`ContextMode::Syntax`] (syntax-index based retrieval).
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default [`ZetaOptions`] used when constructing a [`Zeta`] instance.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
 140
 141static USE_OLLAMA: LazyLock<bool> =
 142    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 143static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 144    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
 145        "qwen3-coder:30b".to_string()
 146    } else {
 147        "yqvev8r3".to_string()
 148    })
 149});
 150static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 151    match env::var("ZED_ZETA2_MODEL").as_deref() {
 152        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 153        Ok(model) => model,
 154        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 155        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 156    }
 157    .to_string()
 158});
 159static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 160    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 161        if *USE_OLLAMA {
 162            Some("http://localhost:11434/v1/chat/completions".into())
 163        } else {
 164            None
 165        }
 166    })
 167});
 168
/// Feature flag gating the Zeta2 edit prediction model; enabled by default
/// for Zed staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}
 178
/// Global handle to the singleton [`Zeta`] entity (see [`Zeta::global`]).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 183
/// Coordinates edit prediction across projects: tracks buffer-change events,
/// issues prediction requests, and records shown/rated/rejected predictions.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token for the LLM backend; refreshed via `_llm_token_subscription`.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set when the backend indicates a newer Zed version is required
    /// (see `MINIMUM_REQUIRED_VERSION_HEADER_NAME`) — TODO confirm at use site.
    update_required: bool,
    /// When present, debug events are streamed to this channel (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    /// Rejections buffered locally until flushed to the server in a batch.
    rejected_predictions: Vec<EditPredictionRejection>,
    /// Signals the background task that flushes `rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<()>,
    reject_predictions_debounce_task: Option<Task<()>>,
    /// Predictions that were displayed to the user.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
 204
/// Which backend serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}
 212
/// Tunable parameters for context gathering and prompt assembly.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    /// Upper bound on total prompt size, in bytes.
    pub max_prompt_bytes: usize,
    /// Upper bound on diagnostics included in the prompt, in bytes.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    // NOTE(review): appears to control time-based grouping of buffer change
    // events — confirm at use site (not visible in this chunk).
    pub buffer_change_grouping_interval: Duration,
}
 222
/// Strategy used to gather context for a prediction request.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Model-driven retrieval via generated search queries.
    Agentic(AgenticContextOptions),
    /// Retrieval based on a local [`SyntaxIndex`].
    Syntax(EditPredictionContextOptions),
}

/// Options for [`ContextMode::Agentic`].
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
 233
impl ContextMode {
    /// Excerpt sizing options, common to both context modes.
    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
        match self {
            ContextMode::Agentic(options) => &options.excerpt,
            ContextMode::Syntax(options) => &options.excerpt,
        }
    }
}
 242
/// Events streamed to debug listeners registered via [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

/// Details for [`ZetaDebugInfo::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// Prompt used for the search-query generation step.
    pub search_prompt: String,
}

/// Project/timestamp pair shared by several context-retrieval debug events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Details for [`ZetaDebugInfo::EditPredictionRequested`].
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    /// Time spent on context retrieval before the request was issued.
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// The locally assembled prompt, or an error string if assembly failed.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error) and the request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}

/// Details for [`ZetaDebugInfo::SearchQueriesGenerated`].
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 283
/// Per-project state tracked by [`Zeta`].
struct ZetaProject {
    /// Present only in `ContextMode::Syntax`; indexes declarations for context
    /// retrieval.
    syntax_index: Option<Entity<SyntaxIndex>>,
    /// Finalized buffer-change events, oldest first; capped so that, together
    /// with `last_event`, at most `EVENT_COUNT_MAX` are tracked.
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    /// The most recent change, kept un-finalized so nearby edits can coalesce.
    last_event: Option<LastEvent>,
    /// Recently active project paths, most recent first, duplicate-free.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two at a time).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that should be rejected when they resolve.
    cancelled_predictions: HashSet<usize>,
    /// Retrieved context ranges per buffer, when retrieval has run.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    /// One license watcher per worktree, used to tag events as open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
 302
impl ZetaProject {
    /// All finalized events, plus the still-pending `last_event` (finalized on
    /// the fly) when it produced an observable change.
    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
        self.events
            .iter()
            .cloned()
            .chain(
                self.last_event
                    .as_ref()
                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
            )
            .collect()
    }

    /// Marks `pending_prediction` as cancelled and, once its task resolves to
    /// a prediction id, reports that prediction as rejected (reason: Canceled).
    fn cancel_pending_prediction(
        &mut self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<Zeta>,
    ) {
        self.cancelled_predictions.insert(pending_prediction.id);

        cx.spawn(async move |this, cx| {
            // The task may resolve to `None` (no prediction was produced);
            // nothing to reject in that case.
            let Some(prediction_id) = pending_prediction.task.await else {
                return;
            };

            this.update(cx, |this, cx| {
                this.reject_prediction(
                    prediction_id,
                    EditPredictionRejectReason::Canceled,
                    false,
                    cx,
                );
            })
            .ok();
        })
        .detach()
    }
}
 341
/// The prediction currently offered to the user for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been displayed to the user yet.
    pub was_shown: bool,
}
 348
impl CurrentEditPrediction {
    /// Whether `self` (a newly computed prediction) should replace
    /// `old_prediction` as the project's current prediction.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // If the new prediction no longer applies to its buffer's current
        // contents, keep the old one.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // The old prediction no longer interpolates against the buffer, so any
        // valid new prediction wins.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // Only replace when the new single edit extends the old one at the
            // same range.
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
 387
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// A diagnostics update for the project.
    DiagnosticsUpdate,
    /// A request originating in the given buffer.
    Buffer(EntityId),
}
 393
 394impl PredictionRequestedBy {
 395    pub fn buffer_id(&self) -> Option<EntityId> {
 396        match self {
 397            PredictionRequestedBy::DiagnosticsUpdate => None,
 398            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 399        }
 400    }
 401}
 402
/// An in-flight prediction request; the task resolves to the id of the
/// resulting prediction, or `None` if the request produced nothing.
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}
 408
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction edits another buffer; this buffer offers a jump to it.
    Jump { prediction: &'a EditPrediction },
}
 415
 416#[cfg(test)]
 417impl std::ops::Deref for BufferEditPrediction<'_> {
 418    type Target = EditPrediction;
 419
 420    fn deref(&self) -> &Self::Target {
 421        match self {
 422            BufferEditPrediction::Local { prediction } => prediction,
 423            BufferEditPrediction::Jump { prediction } => prediction,
 424        }
 425    }
 426}
 427
/// Per-buffer bookkeeping: the last snapshot we diffed against, plus the edit
/// and release subscriptions that keep it current.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// A buffer change that has not yet been finalized into a
/// `predict_edits_v3::Event`, so that consecutive nearby edits can coalesce.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    /// Anchor at the end of the most recent edit; used to decide whether the
    /// next edit is close enough to coalesce.
    end_edit_anchor: Option<Anchor>,
}
 438
impl LastEvent {
    /// Turns this pending change into a `BufferChange` event, returning `None`
    /// when the change is unobservable (identical path and empty diff).
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<Arc<predict_edits_v3::Event>> {
        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);

        let file = self.new_snapshot.file();
        let old_file = self.old_snapshot.file();

        // Open source only if both the old and new file belong to worktrees
        // whose license watcher detected an open-source project; untitled
        // buffers (no file) are treated as not open source.
        let in_open_source_repo = [file, old_file].iter().all(|file| {
            file.is_some_and(|file| {
                license_detection_watchers
                    .get(&file.worktree_id(cx))
                    .is_some_and(|watcher| watcher.is_project_open_source())
            })
        });

        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(Arc::new(predict_edits_v3::Event::BufferChange {
                old_path,
                path,
                diff,
                in_open_source_repo,
                // TODO: Actually detect if this edit was predicted or not
                predicted: false,
            }))
        }
    }
}
 475
 476fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
 477    if let Some(file) = snapshot.file() {
 478        file.full_path(cx).into()
 479    } else {
 480        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 481    }
 482}
 483
 484impl Zeta {
 485    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 486        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 487    }
 488
 489    pub fn global(
 490        client: &Arc<Client>,
 491        user_store: &Entity<UserStore>,
 492        cx: &mut App,
 493    ) -> Entity<Self> {
 494        cx.try_global::<ZetaGlobal>()
 495            .map(|global| global.0.clone())
 496            .unwrap_or_else(|| {
 497                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 498                cx.set_global(ZetaGlobal(zeta.clone()));
 499                zeta
 500            })
 501    }
 502
    /// Creates a new [`Zeta`]; prefer [`Zeta::global`], which maintains the
    /// singleton.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        // Background task that flushes buffered rejections: each `()` received
        // on the channel triggers one flush to the server.
        let (reject_tx, mut reject_rx) = mpsc::unbounded();
        cx.spawn(async move |this, cx| {
            while let Some(()) = reject_rx.next().await {
                this.update(cx, |this, cx| this.flush_rejected_predictions(cx))?
                    .await
                    .log_err();
            }
            anyhow::Ok(())
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Keep the LLM token fresh whenever the listener reports a refresh.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            rejected_predictions: Vec::new(),
            reject_predictions_debounce_task: None,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
 550
    /// Selects which backend serves edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Whether a Sweep AI API token is configured.
    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_ai.api_token.is_some()
    }

    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }

    /// Starts streaming debug events and returns the receiving end. Calling
    /// this again replaces the previous channel.
    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 577
 578    pub fn clear_history(&mut self) {
 579        for zeta_project in self.projects.values_mut() {
 580            zeta_project.events.clear();
 581        }
 582    }
 583
 584    pub fn context_for_project(
 585        &self,
 586        project: &Entity<Project>,
 587    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 588        self.projects
 589            .get(&project.entity_id())
 590            .and_then(|project| {
 591                Some(
 592                    project
 593                        .context
 594                        .as_ref()?
 595                        .iter()
 596                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 597                )
 598            })
 599            .into_iter()
 600            .flatten()
 601    }
 602
 603    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 604        if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
 605            self.user_store.read(cx).edit_prediction_usage()
 606        } else {
 607            None
 608        }
 609    }
 610
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking `buffer` (snapshot plus edit/release subscriptions)
    /// within `project`, initializing project state as needed.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 624
    /// Returns the [`ZetaProject`] for `project`, creating it (and subscribing
    /// to its events) on first access.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                // The syntax index is only needed for syntax-based context
                // retrieval; agentic mode skips indexing entirely.
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
 657
    /// Reacts to project events: maintains the MRU list of recently active
    /// paths and refreshes predictions when diagnostics change.
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front, removing any previous
                    // occurrence so the list stays duplicate-free.
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                // Diagnostics-driven refresh is gated behind the zeta2 flag.
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }
 690
    /// Returns the [`RegisteredBuffer`] for `buffer`, creating it on first
    /// registration. Also ensures a [`LicenseDetectionWatcher`] exists for the
    /// buffer's worktree.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        // Drop the watcher when its worktree is released.
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record change events whenever the buffer is edited.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Deregister the buffer when it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 750
    /// Records an edit to `buffer` as a change event, coalescing it into the
    /// pending `last_event` when it continues a nearby edit in the same buffer.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Nothing changed since the last snapshot we recorded.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Anchor at the end of the last edit in this batch; used for the
        // proximity check below.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // The new edit picks up exactly where the pending event left off...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and lies within CHANGE_GROUPING_LINE_SPAN rows of it.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // Extend the pending event instead of starting a new one.
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Cap the history, counting the pending event that is about to be
        // pushed below.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        // Finalize the previous pending event (dropped if it had no effect).
        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
 813
    /// Returns the current prediction as seen from `buffer`: `Local` when the
    /// prediction edits this buffer, `Jump` when it edits another buffer but
    /// this buffer should offer to jump there, `None` otherwise.
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            // Only surface a jump affordance in the buffer that requested the
            // prediction; diagnostics-driven predictions may jump from any buffer.
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
 845
    /// Accepts the project's current prediction: cancels any still-pending
    /// requests and reports the acceptance to the backend.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        // Acceptance reporting only applies to the Zed-hosted Zeta models.
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Any other in-flight requests are now stale; cancel and reject them.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the endpoint (e.g. for
            // local testing).
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 898
    /// Sends all accumulated prediction rejections to the backend in one batch.
    ///
    /// No-op for the Sweep model or when there is nothing to flush. On success,
    /// removes every rejection up to and including the last one that was part
    /// of this batch; rejections recorded while the request was in flight are
    /// kept for the next flush.
    fn flush_rejected_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        // Remember the last rejection in this batch so we know how much of the
        // queue to drain after a successful request.
        let last_rejection = self.rejected_predictions.last().cloned();
        let Some(last_rejection) = last_rejection else {
            return Task::ready(anyhow::Ok(()));
        };

        // NOTE(review): serialization failure is silently dropped via `.ok()`;
        // presumably it cannot fail for this body — confirm.
        let body = serde_json::to_string(&RejectEditPredictionsBody {
            rejections: self.rejected_predictions.clone(),
        })
        .ok();

        cx.spawn(async move |this, cx| {
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])?;

            cx.background_spawn(Self::send_api_request::<()>(
                move |builder| {
                    let req = builder.uri(url.as_ref()).body(body.clone().into());
                    Ok(req?)
                },
                client,
                llm_token,
                app_version,
            ))
            .await
            .context("Failed to reject edit predictions")?;

            // Drain only the entries that were included in this request.
            this.update(cx, |this, _| {
                if let Some(ix) = this
                    .rejected_predictions
                    .iter()
                    .position(|rejection| rejection.request_id == last_rejection.request_id)
                {
                    this.rejected_predictions.drain(..ix + 1);
                }
            })
        })
    }
 946
 947    fn reject_current_prediction(
 948        &mut self,
 949        reason: EditPredictionRejectReason,
 950        project: &Entity<Project>,
 951        cx: &mut Context<Self>,
 952    ) {
 953        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 954            project_state.pending_predictions.clear();
 955            if let Some(prediction) = project_state.current_prediction.take() {
 956                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
 957            }
 958        };
 959    }
 960
 961    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
 962        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 963            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
 964                if !current_prediction.was_shown {
 965                    current_prediction.was_shown = true;
 966                    self.shown_predictions
 967                        .push_front(current_prediction.prediction.clone());
 968                    if self.shown_predictions.len() > 50 {
 969                        let completion = self.shown_predictions.pop_back().unwrap();
 970                        self.rated_predictions.remove(&completion.id);
 971                    }
 972                }
 973            }
 974        }
 975    }
 976
    /// Records a prediction rejection and schedules a flush to the backend.
    ///
    /// Rejections are batched: normally the flush is debounced by 15 seconds
    /// (each new rejection restarts the timer by replacing the debounce task),
    /// but once the batch reaches the per-request limit the flush is signalled
    /// immediately.
    fn reject_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        reason: EditPredictionRejectReason,
        was_shown: bool,
        cx: &mut Context<Self>,
    ) {
        self.rejected_predictions.push(EditPredictionRejection {
            request_id: prediction_id.to_string(),
            reason,
            was_shown,
        });

        let reached_request_limit =
            self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
        let reject_tx = self.reject_predictions_tx.clone();
        // Replacing the previous task cancels it, restarting the debounce window.
        self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
            const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
            if !reached_request_limit {
                cx.background_executor()
                    .timer(REJECT_REQUEST_DEBOUNCE)
                    .await;
            }
            // Wake the flush loop listening on the other end of this channel.
            reject_tx.unbounded_send(()).log_err();
        }));
    }
1003
1004    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1005        self.projects
1006            .get(&project.entity_id())
1007            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1008    }
1009
1010    pub fn refresh_prediction_from_buffer(
1011        &mut self,
1012        project: Entity<Project>,
1013        buffer: Entity<Buffer>,
1014        position: language::Anchor,
1015        cx: &mut Context<Self>,
1016    ) {
1017        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1018            let Some(request_task) = this
1019                .update(cx, |this, cx| {
1020                    this.request_prediction(
1021                        &project,
1022                        &buffer,
1023                        position,
1024                        PredictEditsRequestTrigger::Other,
1025                        cx,
1026                    )
1027                })
1028                .log_err()
1029            else {
1030                return Task::ready(anyhow::Ok(None));
1031            };
1032
1033            cx.spawn(async move |_cx| {
1034                request_task.await.map(|prediction_result| {
1035                    prediction_result.map(|prediction_result| {
1036                        (
1037                            prediction_result,
1038                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1039                        )
1040                    })
1041                })
1042            })
1043        })
1044    }
1045
    /// Queues a prediction refresh triggered by a diagnostics update.
    ///
    /// Finds the next diagnostic location starting from the project's active
    /// entry and requests a prediction there. Buffer-initiated predictions are
    /// preferred: this is a no-op when a current prediction already exists,
    /// and if one appears while this request is in flight, the result is
    /// recorded as rejected with `CurrentPreferred` instead of replacing it.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        // Throttled per-project (keyed by the project's entity id).
        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Open the buffer for the project's active entry, if any.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Empty search range / cursor point: consider all diagnostics.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // A buffer-initiated prediction may have arrived meanwhile; if
                // so, keep it and mark this result as rejected in its favor.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_zeta_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1126
    /// Minimum delay between consecutive prediction refreshes for the same
    /// entity (see `queue_prediction_refresh`).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// In tests, skip throttling so they run fast and deterministically.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1131
    /// Queues `do_refresh` as a pending prediction request, handling
    /// throttling, cancellation, and installation of the resulting prediction.
    ///
    /// Requests for the same `throttle_entity` are spaced at least
    /// `THROTTLE_TIMEOUT` apart. At most two requests are pending at once;
    /// queueing a third replaces (and cancels) the newest pending one. When a
    /// request completes it may replace the current prediction, or be recorded
    /// as rejected if the current one is preferred; any pending requests that
    /// were enqueued before it are cancelled.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        let zeta_project = self.get_or_init_zeta_project(&project, cx);
        let pending_prediction_id = zeta_project.next_pending_prediction_id;
        zeta_project.next_pending_prediction_id += 1;
        let last_request = zeta_project.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Throttle: wait out the remainder of THROTTLE_TIMEOUT since the
            // last refresh for this same entity.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_zeta_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let zeta_project = this.get_or_init_zeta_project(&project, cx);

                // The request may also have been cancelled while in flight.
                let is_cancelled = zeta_project
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                            };

                            // Decide whether the new prediction supersedes the
                            // current one; exactly one of them gets rejected.
                            if let Some(current_prediction) =
                                zeta_project.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                        cx,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                        cx,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        // The request itself reported a rejection reason.
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false, cx);
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-borrow: the reject_* calls above required `this` mutably.
                let zeta_project = this.get_or_init_zeta_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    zeta_project.current_prediction = Some(new_prediction);
                }

                // Drop this entry from the pending list and cancel all entries
                // that were enqueued before it (they are now stale).
                let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            zeta_project.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_zeta_project(&project, cx)
                    .pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Keep at most two pending requests; a third replaces (and cancels)
        // the most recently queued one.
        if zeta_project.pending_predictions.len() <= 1 {
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if zeta_project.pending_predictions.len() == 2 {
            let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            zeta_project.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1274
1275    pub fn request_prediction(
1276        &mut self,
1277        project: &Entity<Project>,
1278        active_buffer: &Entity<Buffer>,
1279        position: language::Anchor,
1280        trigger: PredictEditsRequestTrigger,
1281        cx: &mut Context<Self>,
1282    ) -> Task<Result<Option<EditPredictionResult>>> {
1283        self.request_prediction_internal(
1284            project.clone(),
1285            active_buffer.clone(),
1286            position,
1287            trigger,
1288            cx.has_flag::<Zeta2FeatureFlag>(),
1289            cx,
1290        )
1291    }
1292
    /// Dispatches a prediction request to the configured model.
    ///
    /// If the model yields no prediction, `allow_jump` is set, and there have
    /// been recent edit events, retries once at the nearest diagnostic outside
    /// the cursor's vicinity (possibly in another buffer), with `allow_jump`
    /// cleared so the retry cannot recurse further.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Diagnostics within this many rows of the cursor are considered
        // "nearby" and excluded from jump targets.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            // No prediction here: optionally retry once at a diagnostic
            // further away (or in another buffer).
            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    // `allow_jump: false` bounds this to a single retry.
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1384
    /// Finds the next diagnostic to jump to, returning its buffer and position.
    ///
    /// First looks in the active buffer for the diagnostic closest (by row) to
    /// the cursor that lies outside `active_buffer_diagnostic_search_range`
    /// (i.e. wasn't already near enough to be included in the last request).
    /// If none exists, falls back to the project's diagnostic summaries,
    /// picking the other file whose path shares the most leading components
    /// with the active buffer's path, and jumping to its most severe entry.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // Fall back to diagnostics elsewhere in the project.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // Lower severity values are more severe, so `min_by_key`
                // selects the most severe diagnostic in that buffer.
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1463
1464    fn request_prediction_with_zeta2(
1465        &mut self,
1466        project: &Entity<Project>,
1467        active_buffer: &Entity<Buffer>,
1468        active_snapshot: BufferSnapshot,
1469        position: language::Anchor,
1470        events: Vec<Arc<Event>>,
1471        trigger: PredictEditsRequestTrigger,
1472        cx: &mut Context<Self>,
1473    ) -> Task<Result<Option<EditPredictionResult>>> {
1474        let project_state = self.projects.get(&project.entity_id());
1475
1476        let index_state = project_state.and_then(|state| {
1477            state
1478                .syntax_index
1479                .as_ref()
1480                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
1481        });
1482        let options = self.options.clone();
1483        let buffer_snapshotted_at = Instant::now();
1484        let Some(excerpt_path) = active_snapshot
1485            .file()
1486            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
1487        else {
1488            return Task::ready(Err(anyhow!("No file path for excerpt")));
1489        };
1490        let client = self.client.clone();
1491        let llm_token = self.llm_token.clone();
1492        let app_version = AppVersion::global(cx);
1493        let worktree_snapshots = project
1494            .read(cx)
1495            .worktrees(cx)
1496            .map(|worktree| worktree.read(cx).snapshot())
1497            .collect::<Vec<_>>();
1498        let debug_tx = self.debug_tx.clone();
1499
1500        let diagnostics = active_snapshot.diagnostic_sets().clone();
1501
1502        let file = active_buffer.read(cx).file();
1503        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
1504            let mut path = f.worktree.read(cx).absolutize(&f.path);
1505            if path.pop() { Some(path) } else { None }
1506        });
1507
1508        // TODO data collection
1509        let can_collect_data = file
1510            .as_ref()
1511            .map_or(false, |file| self.can_collect_file(project, file, cx));
1512
1513        let empty_context_files = HashMap::default();
1514        let context_files = project_state
1515            .and_then(|project_state| project_state.context.as_ref())
1516            .unwrap_or(&empty_context_files);
1517
1518        #[cfg(feature = "eval-support")]
1519        let parsed_fut = futures::future::join_all(
1520            context_files
1521                .keys()
1522                .map(|buffer| buffer.read(cx).parsing_idle()),
1523        );
1524
1525        let mut included_files = context_files
1526            .iter()
1527            .filter_map(|(buffer_entity, ranges)| {
1528                let buffer = buffer_entity.read(cx);
1529                Some((
1530                    buffer_entity.clone(),
1531                    buffer.snapshot(),
1532                    buffer.file()?.full_path(cx).into(),
1533                    ranges.clone(),
1534                ))
1535            })
1536            .collect::<Vec<_>>();
1537
1538        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
1539            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
1540        });
1541
1542        #[cfg(feature = "eval-support")]
1543        let eval_cache = self.eval_cache.clone();
1544
1545        let request_task = cx.background_spawn({
1546            let active_buffer = active_buffer.clone();
1547            async move {
1548                #[cfg(feature = "eval-support")]
1549                parsed_fut.await;
1550
1551                let index_state = if let Some(index_state) = index_state {
1552                    Some(index_state.lock_owned().await)
1553                } else {
1554                    None
1555                };
1556
1557                let cursor_offset = position.to_offset(&active_snapshot);
1558                let cursor_point = cursor_offset.to_point(&active_snapshot);
1559
1560                let before_retrieval = Instant::now();
1561
1562                let (diagnostic_groups, diagnostic_groups_truncated) =
1563                    Self::gather_nearby_diagnostics(
1564                        cursor_offset,
1565                        &diagnostics,
1566                        &active_snapshot,
1567                        options.max_diagnostic_bytes,
1568                    );
1569
1570                let cloud_request = match options.context {
1571                    ContextMode::Agentic(context_options) => {
1572                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
1573                            cursor_point,
1574                            &active_snapshot,
1575                            &context_options.excerpt,
1576                            index_state.as_deref(),
1577                        ) else {
1578                            return Ok((None, None));
1579                        };
1580
1581                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
1582                            ..active_snapshot.anchor_before(excerpt.range.end);
1583
1584                        if let Some(buffer_ix) =
1585                            included_files.iter().position(|(_, snapshot, _, _)| {
1586                                snapshot.remote_id() == active_snapshot.remote_id()
1587                            })
1588                        {
1589                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
1590                            ranges.push(excerpt_anchor_range);
1591                            retrieval_search::merge_anchor_ranges(ranges, buffer);
1592                            let last_ix = included_files.len() - 1;
1593                            included_files.swap(buffer_ix, last_ix);
1594                        } else {
1595                            included_files.push((
1596                                active_buffer.clone(),
1597                                active_snapshot.clone(),
1598                                excerpt_path.clone(),
1599                                vec![excerpt_anchor_range],
1600                            ));
1601                        }
1602
1603                        let included_files = included_files
1604                            .iter()
1605                            .map(|(_, snapshot, path, ranges)| {
1606                                let ranges = ranges
1607                                    .iter()
1608                                    .map(|range| {
1609                                        let point_range = range.to_point(&snapshot);
1610                                        Line(point_range.start.row)..Line(point_range.end.row)
1611                                    })
1612                                    .collect::<Vec<_>>();
1613                                let excerpts = assemble_excerpts(&snapshot, ranges);
1614                                predict_edits_v3::IncludedFile {
1615                                    path: path.clone(),
1616                                    max_row: Line(snapshot.max_point().row),
1617                                    excerpts,
1618                                }
1619                            })
1620                            .collect::<Vec<_>>();
1621
1622                        predict_edits_v3::PredictEditsRequest {
1623                            excerpt_path,
1624                            excerpt: String::new(),
1625                            excerpt_line_range: Line(0)..Line(0),
1626                            excerpt_range: 0..0,
1627                            cursor_point: predict_edits_v3::Point {
1628                                line: predict_edits_v3::Line(cursor_point.row),
1629                                column: cursor_point.column,
1630                            },
1631                            included_files,
1632                            referenced_declarations: vec![],
1633                            events,
1634                            can_collect_data,
1635                            diagnostic_groups,
1636                            diagnostic_groups_truncated,
1637                            debug_info: debug_tx.is_some(),
1638                            prompt_max_bytes: Some(options.max_prompt_bytes),
1639                            prompt_format: options.prompt_format,
1640                            // TODO [zeta2]
1641                            signatures: vec![],
1642                            excerpt_parent: None,
1643                            git_info: None,
1644                            trigger,
1645                        }
1646                    }
1647                    ContextMode::Syntax(context_options) => {
1648                        let Some(context) = EditPredictionContext::gather_context(
1649                            cursor_point,
1650                            &active_snapshot,
1651                            parent_abs_path.as_deref(),
1652                            &context_options,
1653                            index_state.as_deref(),
1654                        ) else {
1655                            return Ok((None, None));
1656                        };
1657
1658                        make_syntax_context_cloud_request(
1659                            excerpt_path,
1660                            context,
1661                            events,
1662                            can_collect_data,
1663                            diagnostic_groups,
1664                            diagnostic_groups_truncated,
1665                            None,
1666                            debug_tx.is_some(),
1667                            &worktree_snapshots,
1668                            index_state.as_deref(),
1669                            Some(options.max_prompt_bytes),
1670                            options.prompt_format,
1671                            trigger,
1672                        )
1673                    }
1674                };
1675
1676                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
1677
1678                let inputs = EditPredictionInputs {
1679                    included_files: cloud_request.included_files,
1680                    events: cloud_request.events,
1681                    cursor_point: cloud_request.cursor_point,
1682                    cursor_path: cloud_request.excerpt_path,
1683                };
1684
1685                let retrieval_time = Instant::now() - before_retrieval;
1686
1687                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
1688                    let (response_tx, response_rx) = oneshot::channel();
1689
1690                    debug_tx
1691                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
1692                            ZetaEditPredictionDebugInfo {
1693                                inputs: inputs.clone(),
1694                                retrieval_time,
1695                                buffer: active_buffer.downgrade(),
1696                                local_prompt: match prompt_result.as_ref() {
1697                                    Ok((prompt, _)) => Ok(prompt.clone()),
1698                                    Err(err) => Err(err.to_string()),
1699                                },
1700                                position,
1701                                response_rx,
1702                            },
1703                        ))
1704                        .ok();
1705                    Some(response_tx)
1706                } else {
1707                    None
1708                };
1709
1710                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
1711                    if let Some(debug_response_tx) = debug_response_tx {
1712                        debug_response_tx
1713                            .send((Err("Request skipped".to_string()), Duration::ZERO))
1714                            .ok();
1715                    }
1716                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
1717                }
1718
1719                let (prompt, _) = prompt_result?;
1720                let generation_params =
1721                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
1722                let request = open_ai::Request {
1723                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1724                    messages: vec![open_ai::RequestMessage::User {
1725                        content: open_ai::MessageContent::Plain(prompt),
1726                    }],
1727                    stream: false,
1728                    max_completion_tokens: None,
1729                    stop: generation_params.stop.unwrap_or_default(),
1730                    temperature: generation_params.temperature.unwrap_or(0.7),
1731                    tool_choice: None,
1732                    parallel_tool_calls: None,
1733                    tools: vec![],
1734                    prompt_cache_key: None,
1735                    reasoning_effort: None,
1736                };
1737
1738                log::trace!("Sending edit prediction request");
1739
1740                let before_request = Instant::now();
1741                let response = Self::send_raw_llm_request(
1742                    request,
1743                    client,
1744                    llm_token,
1745                    app_version,
1746                    #[cfg(feature = "eval-support")]
1747                    eval_cache,
1748                    #[cfg(feature = "eval-support")]
1749                    EvalCacheEntryKind::Prediction,
1750                )
1751                .await;
1752                let received_response_at = Instant::now();
1753                let request_time = received_response_at - before_request;
1754
1755                log::trace!("Got edit prediction response");
1756
1757                if let Some(debug_response_tx) = debug_response_tx {
1758                    debug_response_tx
1759                        .send((
1760                            response
1761                                .as_ref()
1762                                .map_err(|err| err.to_string())
1763                                .map(|response| response.0.clone()),
1764                            request_time,
1765                        ))
1766                        .ok();
1767                }
1768
1769                let (res, usage) = response?;
1770                let request_id = EditPredictionId(res.id.clone().into());
1771                let Some(mut output_text) = text_from_response(res) else {
1772                    return Ok((Some((request_id, None)), usage));
1773                };
1774
1775                if output_text.contains(CURSOR_MARKER) {
1776                    log::trace!("Stripping out {CURSOR_MARKER} from response");
1777                    output_text = output_text.replace(CURSOR_MARKER, "");
1778                }
1779
1780                let get_buffer_from_context = |path: &Path| {
1781                    included_files
1782                        .iter()
1783                        .find_map(|(_, buffer, probe_path, ranges)| {
1784                            if probe_path.as_ref() == path {
1785                                Some((buffer, ranges.as_slice()))
1786                            } else {
1787                                None
1788                            }
1789                        })
1790                };
1791
1792                let (edited_buffer_snapshot, edits) = match options.prompt_format {
1793                    PromptFormat::NumLinesUniDiff => {
1794                        // TODO: Implement parsing of multi-file diffs
1795                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1796                    }
1797                    PromptFormat::Minimal
1798                    | PromptFormat::MinimalQwen
1799                    | PromptFormat::SeedCoder1120 => {
1800                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
1801                            let edits = vec![];
1802                            (&active_snapshot, edits)
1803                        } else {
1804                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1805                        }
1806                    }
1807                    PromptFormat::OldTextNewText => {
1808                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
1809                            .await?
1810                    }
1811                    _ => {
1812                        bail!("unsupported prompt format {}", options.prompt_format)
1813                    }
1814                };
1815
1816                let edited_buffer = included_files
1817                    .iter()
1818                    .find_map(|(buffer, snapshot, _, _)| {
1819                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1820                            Some(buffer.clone())
1821                        } else {
1822                            None
1823                        }
1824                    })
1825                    .context("Failed to find buffer in included_buffers")?;
1826
1827                anyhow::Ok((
1828                    Some((
1829                        request_id,
1830                        Some((
1831                            inputs,
1832                            edited_buffer,
1833                            edited_buffer_snapshot.clone(),
1834                            edits,
1835                            received_response_at,
1836                        )),
1837                    )),
1838                    usage,
1839                ))
1840            }
1841        });
1842
1843        cx.spawn({
1844            async move |this, cx| {
1845                let Some((id, prediction)) =
1846                    Self::handle_api_response(&this, request_task.await, cx)?
1847                else {
1848                    return Ok(None);
1849                };
1850
1851                let Some((
1852                    inputs,
1853                    edited_buffer,
1854                    edited_buffer_snapshot,
1855                    edits,
1856                    received_response_at,
1857                )) = prediction
1858                else {
1859                    return Ok(Some(EditPredictionResult {
1860                        id,
1861                        prediction: Err(EditPredictionRejectReason::Empty),
1862                    }));
1863                };
1864
1865                // TODO telemetry: duration, etc
1866                Ok(Some(
1867                    EditPredictionResult::new(
1868                        id,
1869                        &edited_buffer,
1870                        &edited_buffer_snapshot,
1871                        edits.into(),
1872                        buffer_snapshotted_at,
1873                        received_response_at,
1874                        inputs,
1875                        cx,
1876                    )
1877                    .await,
1878                ))
1879            }
1880        })
1881    }
1882
1883    async fn send_raw_llm_request(
1884        request: open_ai::Request,
1885        client: Arc<Client>,
1886        llm_token: LlmApiToken,
1887        app_version: Version,
1888        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1889        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
1890    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1891        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1892            http_client::Url::parse(&predict_edits_url)?
1893        } else {
1894            client
1895                .http_client()
1896                .build_zed_llm_url("/predict_edits/raw", &[])?
1897        };
1898
1899        #[cfg(feature = "eval-support")]
1900        let cache_key = if let Some(cache) = eval_cache {
1901            use collections::FxHasher;
1902            use std::hash::{Hash, Hasher};
1903
1904            let mut hasher = FxHasher::default();
1905            url.hash(&mut hasher);
1906            let request_str = serde_json::to_string_pretty(&request)?;
1907            request_str.hash(&mut hasher);
1908            let hash = hasher.finish();
1909
1910            let key = (eval_cache_kind, hash);
1911            if let Some(response_str) = cache.read(key) {
1912                return Ok((serde_json::from_str(&response_str)?, None));
1913            }
1914
1915            Some((cache, request_str, key))
1916        } else {
1917            None
1918        };
1919
1920        let (response, usage) = Self::send_api_request(
1921            |builder| {
1922                let req = builder
1923                    .uri(url.as_ref())
1924                    .body(serde_json::to_string(&request)?.into());
1925                Ok(req?)
1926            },
1927            client,
1928            llm_token,
1929            app_version,
1930        )
1931        .await?;
1932
1933        #[cfg(feature = "eval-support")]
1934        if let Some((cache, request, key)) = cache_key {
1935            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1936        }
1937
1938        Ok((response, usage))
1939    }
1940
1941    fn handle_api_response<T>(
1942        this: &WeakEntity<Self>,
1943        response: Result<(T, Option<EditPredictionUsage>)>,
1944        cx: &mut gpui::AsyncApp,
1945    ) -> Result<T> {
1946        match response {
1947            Ok((data, usage)) => {
1948                if let Some(usage) = usage {
1949                    this.update(cx, |this, cx| {
1950                        this.user_store.update(cx, |user_store, cx| {
1951                            user_store.update_edit_prediction_usage(usage, cx);
1952                        });
1953                    })
1954                    .ok();
1955                }
1956                Ok(data)
1957            }
1958            Err(err) => {
1959                if err.is::<ZedUpdateRequiredError>() {
1960                    cx.update(|cx| {
1961                        this.update(cx, |this, _cx| {
1962                            this.update_required = true;
1963                        })
1964                        .ok();
1965
1966                        let error_message: SharedString = err.to_string().into();
1967                        show_app_notification(
1968                            NotificationId::unique::<ZedUpdateRequiredError>(),
1969                            cx,
1970                            move |cx| {
1971                                cx.new(|cx| {
1972                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1973                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1974                                })
1975                            },
1976                        );
1977                    })
1978                    .ok();
1979                }
1980                Err(err)
1981            }
1982        }
1983    }
1984
    /// Sends an authenticated JSON POST to the LLM service and deserializes
    /// the response body into `Res`.
    ///
    /// The `build` callback receives a builder pre-populated with the POST
    /// method, content type, bearer token, and Zed version header, and must
    /// attach the URI and body.
    ///
    /// Retry behavior: if the server flags the LLM token as expired, the
    /// token is refreshed and the request retried exactly once. Any other
    /// non-success status fails with the status and response body in the
    /// error. If the server advertises a minimum required version newer than
    /// `app_version`, fails with `ZedUpdateRequiredError` regardless of
    /// status.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            // The request is rebuilt each iteration so a retry picks up the
            // refreshed token in the Authorization header.
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Version gating is checked before the status code, so an
            // outdated client fails with the update-required error even when
            // the server also returned a failure status.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                // Usage headers are optional; absence is not an error.
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh and loop for a single retry.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2048
    /// If this long has passed since the last context-refresh trigger, the
    /// next edit is treated as the start of a new editing session and
    /// refreshes context immediately instead of debouncing.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// How long to wait after the most recent edit before refreshing
    /// context (when the user was not idle).
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
2051
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context refresh only applies to the agentic context mode.
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // A missing timestamp (first trigger) counts as having been idle.
        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Storing the new task drops (and thereby supersedes) any previously
        // scheduled debounce task for this project.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    // Coming back from idle: refresh immediately.
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    // Otherwise wait for a pause in editing before refreshing.
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Keep the refresh task alive by storing it on the project
                    // state, if the project is still registered.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
2099
    /// Refreshes the related excerpts asynchronously: asks the context
    /// retrieval model to plan searches for the code around the cursor, runs
    /// those searches, and stores the resulting excerpts on the project
    /// state. Callers should ensure the returned task runs to completion and
    /// avoid spawning more than one concurrent task (see
    /// `refresh_context_if_needed`, which stores it in
    /// `refresh_context_task`).
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // No-op for projects not registered with this instance.
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        // Context retrieval only applies to the agentic context mode.
        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        // Without an excerpt around the cursor there is nothing to base the
        // search planning prompt on.
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        // Unsaved buffers get a placeholder path in the prompt.
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the search-planning prompt from the cursor excerpt and the
        // project's recent edit events.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project.events(cx),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Schema and description for the search tool, computed once. The
        // description is read from the generated schema's top-level
        // `description` field; `unwrap` relies on the `SearchToolInput`
        // schema providing one.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        // Non-streaming request that exposes exactly one tool: the search
        // tool the model should call with its planned queries.
        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect queries from every call to the search tool; calls to
            // unknown tools are logged and skipped, but malformed arguments
            // for the known tool are an error.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Store the results and clear the in-flight task handle. If the
            // project was dropped in the meantime, succeed silently.
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
2311
2312    pub fn set_context(
2313        &mut self,
2314        project: Entity<Project>,
2315        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2316    ) {
2317        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2318            zeta_project.context = Some(context);
2319        }
2320    }
2321
2322    fn gather_nearby_diagnostics(
2323        cursor_offset: usize,
2324        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
2325        snapshot: &BufferSnapshot,
2326        max_diagnostics_bytes: usize,
2327    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
2328        // TODO: Could make this more efficient
2329        let mut diagnostic_groups = Vec::new();
2330        for (language_server_id, diagnostics) in diagnostic_sets {
2331            let mut groups = Vec::new();
2332            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
2333            diagnostic_groups.extend(
2334                groups
2335                    .into_iter()
2336                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
2337            );
2338        }
2339
2340        // sort by proximity to cursor
2341        diagnostic_groups.sort_by_key(|group| {
2342            let range = &group.entries[group.primary_ix].range;
2343            if range.start >= cursor_offset {
2344                range.start - cursor_offset
2345            } else if cursor_offset >= range.end {
2346                cursor_offset - range.end
2347            } else {
2348                (cursor_offset - range.start).min(range.end - cursor_offset)
2349            }
2350        });
2351
2352        let mut results = Vec::new();
2353        let mut diagnostic_groups_truncated = false;
2354        let mut diagnostics_byte_count = 0;
2355        for group in diagnostic_groups {
2356            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
2357            diagnostics_byte_count += raw_value.get().len();
2358            if diagnostics_byte_count > max_diagnostics_bytes {
2359                diagnostic_groups_truncated = true;
2360                break;
2361            }
2362            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
2363        }
2364
2365        (results, diagnostic_groups_truncated)
2366    }
2367
2368    // TODO: Dedupe with similar code in request_prediction?
2369    pub fn cloud_request_for_zeta_cli(
2370        &mut self,
2371        project: &Entity<Project>,
2372        buffer: &Entity<Buffer>,
2373        position: language::Anchor,
2374        cx: &mut Context<Self>,
2375    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
2376        let project_state = self.projects.get(&project.entity_id());
2377
2378        let index_state = project_state.and_then(|state| {
2379            state
2380                .syntax_index
2381                .as_ref()
2382                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
2383        });
2384        let options = self.options.clone();
2385        let snapshot = buffer.read(cx).snapshot();
2386        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
2387            return Task::ready(Err(anyhow!("No file path for excerpt")));
2388        };
2389        let worktree_snapshots = project
2390            .read(cx)
2391            .worktrees(cx)
2392            .map(|worktree| worktree.read(cx).snapshot())
2393            .collect::<Vec<_>>();
2394
2395        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
2396            let mut path = f.worktree.read(cx).absolutize(&f.path);
2397            if path.pop() { Some(path) } else { None }
2398        });
2399
2400        cx.background_spawn(async move {
2401            let index_state = if let Some(index_state) = index_state {
2402                Some(index_state.lock_owned().await)
2403            } else {
2404                None
2405            };
2406
2407            let cursor_point = position.to_point(&snapshot);
2408
2409            let debug_info = true;
2410            EditPredictionContext::gather_context(
2411                cursor_point,
2412                &snapshot,
2413                parent_abs_path.as_deref(),
2414                match &options.context {
2415                    ContextMode::Agentic(_) => {
2416                        // TODO
2417                        panic!("Llm mode not supported in zeta cli yet");
2418                    }
2419                    ContextMode::Syntax(edit_prediction_context_options) => {
2420                        edit_prediction_context_options
2421                    }
2422                },
2423                index_state.as_deref(),
2424            )
2425            .context("Failed to select excerpt")
2426            .map(|context| {
2427                make_syntax_context_cloud_request(
2428                    excerpt_path.into(),
2429                    context,
2430                    // TODO pass everything
2431                    Vec::new(),
2432                    false,
2433                    Vec::new(),
2434                    false,
2435                    None,
2436                    debug_info,
2437                    &worktree_snapshots,
2438                    index_state.as_deref(),
2439                    Some(options.max_prompt_bytes),
2440                    options.prompt_format,
2441                    PredictEditsRequestTrigger::Other,
2442                )
2443            })
2444        })
2445    }
2446
2447    pub fn wait_for_initial_indexing(
2448        &mut self,
2449        project: &Entity<Project>,
2450        cx: &mut Context<Self>,
2451    ) -> Task<Result<()>> {
2452        let zeta_project = self.get_or_init_zeta_project(project, cx);
2453        if let Some(syntax_index) = &zeta_project.syntax_index {
2454            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2455        } else {
2456            Task::ready(Ok(()))
2457        }
2458    }
2459
2460    fn is_file_open_source(
2461        &self,
2462        project: &Entity<Project>,
2463        file: &Arc<dyn File>,
2464        cx: &App,
2465    ) -> bool {
2466        if !file.is_local() || file.is_private() {
2467            return false;
2468        }
2469        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2470            return false;
2471        };
2472        zeta_project
2473            .license_detection_watchers
2474            .get(&file.worktree_id(cx))
2475            .as_ref()
2476            .is_some_and(|watcher| watcher.is_project_open_source())
2477    }
2478
2479    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2480        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2481    }
2482
2483    fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2484        if !self.data_collection_choice.is_enabled() {
2485            return false;
2486        }
2487        events.iter().all(|event| {
2488            matches!(
2489                event.as_ref(),
2490                Event::BufferChange {
2491                    in_open_source_repo: true,
2492                    ..
2493                }
2494            )
2495        })
2496    }
2497
2498    fn load_data_collection_choice() -> DataCollectionChoice {
2499        let choice = KEY_VALUE_STORE
2500            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2501            .log_err()
2502            .flatten();
2503
2504        match choice.as_deref() {
2505            Some("true") => DataCollectionChoice::Enabled,
2506            Some("false") => DataCollectionChoice::Disabled,
2507            Some(_) => {
2508                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2509                DataCollectionChoice::NotAnswered
2510            }
2511            None => DataCollectionChoice::NotAnswered,
2512        }
2513    }
2514
    /// Iterates over the predictions that have been shown to the user, in the
    /// order they were recorded.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2518
    /// Number of predictions that have been shown to the user.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2522
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2526
    /// Records the user's rating for `prediction` and reports it via
    /// telemetry, flushing immediately so the feedback is not lost if the app
    /// quits soon after.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Flush eagerly rather than waiting for the periodic flush.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2545}
2546
2547pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2548    let choice = res.choices.pop()?;
2549    let output_text = match choice.message {
2550        open_ai::RequestMessage::Assistant {
2551            content: Some(open_ai::MessageContent::Plain(content)),
2552            ..
2553        } => content,
2554        open_ai::RequestMessage::Assistant {
2555            content: Some(open_ai::MessageContent::Multipart(mut content)),
2556            ..
2557        } => {
2558            if content.is_empty() {
2559                log::error!("No output from Baseten completion response");
2560                return None;
2561            }
2562
2563            match content.remove(0) {
2564                open_ai::MessagePart::Text { text } => text,
2565                open_ai::MessagePart::Image { .. } => {
2566                    log::error!("Expected text, got an image");
2567                    return None;
2568                }
2569            }
2570        }
2571        _ => {
2572            log::error!("Invalid response message: {:?}", choice.message);
2573            return None;
2574        }
2575    };
2576    Some(output_text)
2577}
2578
/// Error raised when the server demands a newer Zed version before serving
/// edit predictions; `minimum_version` is reported to the user via `Display`.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2586
2587fn make_syntax_context_cloud_request(
2588    excerpt_path: Arc<Path>,
2589    context: EditPredictionContext,
2590    events: Vec<Arc<predict_edits_v3::Event>>,
2591    can_collect_data: bool,
2592    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2593    diagnostic_groups_truncated: bool,
2594    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2595    debug_info: bool,
2596    worktrees: &Vec<worktree::Snapshot>,
2597    index_state: Option<&SyntaxIndexState>,
2598    prompt_max_bytes: Option<usize>,
2599    prompt_format: PromptFormat,
2600    trigger: PredictEditsRequestTrigger,
2601) -> predict_edits_v3::PredictEditsRequest {
2602    let mut signatures = Vec::new();
2603    let mut declaration_to_signature_index = HashMap::default();
2604    let mut referenced_declarations = Vec::new();
2605
2606    for snippet in context.declarations {
2607        let project_entry_id = snippet.declaration.project_entry_id();
2608        let Some(path) = worktrees.iter().find_map(|worktree| {
2609            worktree.entry_for_id(project_entry_id).map(|entry| {
2610                let mut full_path = RelPathBuf::new();
2611                full_path.push(worktree.root_name());
2612                full_path.push(&entry.path);
2613                full_path
2614            })
2615        }) else {
2616            continue;
2617        };
2618
2619        let parent_index = index_state.and_then(|index_state| {
2620            snippet.declaration.parent().and_then(|parent| {
2621                add_signature(
2622                    parent,
2623                    &mut declaration_to_signature_index,
2624                    &mut signatures,
2625                    index_state,
2626                )
2627            })
2628        });
2629
2630        let (text, text_is_truncated) = snippet.declaration.item_text();
2631        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2632            path: path.as_std_path().into(),
2633            text: text.into(),
2634            range: snippet.declaration.item_line_range(),
2635            text_is_truncated,
2636            signature_range: snippet.declaration.signature_range_in_item_text(),
2637            parent_index,
2638            signature_score: snippet.score(DeclarationStyle::Signature),
2639            declaration_score: snippet.score(DeclarationStyle::Declaration),
2640            score_components: snippet.components,
2641        });
2642    }
2643
2644    let excerpt_parent = index_state.and_then(|index_state| {
2645        context
2646            .excerpt
2647            .parent_declarations
2648            .last()
2649            .and_then(|(parent, _)| {
2650                add_signature(
2651                    *parent,
2652                    &mut declaration_to_signature_index,
2653                    &mut signatures,
2654                    index_state,
2655                )
2656            })
2657    });
2658
2659    predict_edits_v3::PredictEditsRequest {
2660        excerpt_path,
2661        excerpt: context.excerpt_text.body,
2662        excerpt_line_range: context.excerpt.line_range,
2663        excerpt_range: context.excerpt.range,
2664        cursor_point: predict_edits_v3::Point {
2665            line: predict_edits_v3::Line(context.cursor_point.row),
2666            column: context.cursor_point.column,
2667        },
2668        referenced_declarations,
2669        included_files: vec![],
2670        signatures,
2671        excerpt_parent,
2672        events,
2673        can_collect_data,
2674        diagnostic_groups,
2675        diagnostic_groups_truncated,
2676        git_info,
2677        debug_info,
2678        prompt_max_bytes,
2679        prompt_format,
2680        trigger,
2681    }
2682}
2683
2684fn add_signature(
2685    declaration_id: DeclarationId,
2686    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2687    signatures: &mut Vec<Signature>,
2688    index: &SyntaxIndexState,
2689) -> Option<usize> {
2690    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2691        return Some(*signature_index);
2692    }
2693    let Some(parent_declaration) = index.declaration(declaration_id) else {
2694        log::error!("bug: missing parent declaration");
2695        return None;
2696    };
2697    let parent_index = parent_declaration.parent().and_then(|parent| {
2698        add_signature(parent, declaration_to_signature_index, signatures, index)
2699    });
2700    let (text, text_is_truncated) = parent_declaration.signature_text();
2701    let signature_index = signatures.len();
2702    signatures.push(Signature {
2703        text: text.into(),
2704        text_is_truncated,
2705        parent_index,
2706        range: parent_declaration.signature_line_range(),
2707    });
2708    declaration_to_signature_index.insert(declaration_id, signature_index);
2709    Some(signature_index)
2710}
2711
/// Cache key used by the eval harness: the entry kind paired with a `u64`
/// discriminator (presumably a hash of the input — confirm against the cache
/// implementations that construct these keys).
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

/// The kind of artifact stored in the eval cache.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}

#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    // Lowercase names, e.g. used for cache entry labels/paths.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EvalCacheEntryKind::Search => write!(f, "search"),
            EvalCacheEntryKind::Context => write!(f, "context"),
            EvalCacheEntryKind::Prediction => write!(f, "prediction"),
        }
    }
}

/// Storage backend for eval results keyed by [`EvalCacheKey`].
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached value for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` (produced from `input`) under `key`.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2739
/// The user's choice about contributing edit-prediction data, as persisted in
/// the key-value store.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    /// No explicit decision has been recorded yet.
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// Collection is active only after an explicit opt-in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Whether the user made an explicit choice either way.
    pub fn is_answered(self) -> bool {
        matches!(self, Self::Enabled | Self::Disabled)
    }

    /// Returns the opposite choice; an unanswered state toggles to `Enabled`.
    ///
    /// Takes `self` by value (the type is `Copy`) for consistency with the
    /// other methods; previously this inconsistently took `&self`.
    #[must_use]
    pub fn toggle(self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled => Self::Enabled,
            Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
2780
2781struct ZedPredictUpsell;
2782
2783impl Dismissable for ZedPredictUpsell {
2784    const KEY: &'static str = "dismissed-edit-predict-upsell";
2785
2786    fn dismissed() -> bool {
2787        // To make this backwards compatible with older versions of Zed, we
2788        // check if the user has seen the previous Edit Prediction Onboarding
2789        // before, by checking the data collection choice which was written to
2790        // the database once the user clicked on "Accept and Enable"
2791        if KEY_VALUE_STORE
2792            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2793            .log_err()
2794            .is_some_and(|s| s.is_some())
2795        {
2796            return true;
2797        }
2798
2799        KEY_VALUE_STORE
2800            .read_kvp(Self::KEY)
2801            .log_err()
2802            .is_some_and(|s| s.is_some())
2803    }
2804}
2805
/// Whether the edit-prediction upsell modal should be shown (i.e. the user
/// has not dismissed it, directly or via the older onboarding flow).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2809
/// Wires up edit-prediction actions: gates them behind feature flags / AI
/// settings and registers the workspace action handlers.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Rating UI is only available behind the feature flag.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // "Reset onboarding" clears the configured provider in settings.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2845
/// Hides/shows edit-prediction command-palette actions based on the
/// "disable AI" setting and the rate-completions feature flag.
///
/// The `TypeId` arrays are `Copy`, so each `move` closure below captures its
/// own copy.
fn feature_gate_predict_edits_actions(cx: &mut App) {
    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
    let zeta_all_action_types = [
        TypeId::of::<RateCompletions>(),
        TypeId::of::<ResetOnboarding>(),
        zed_actions::OpenZedPredictOnboarding.type_id(),
        TypeId::of::<ClearHistory>(),
        TypeId::of::<ThumbsUpActivePrediction>(),
        TypeId::of::<ThumbsDownActivePrediction>(),
        TypeId::of::<NextEdit>(),
        TypeId::of::<PreviousEdit>(),
    ];

    // Hidden by default; the observers below re-show them when appropriate.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_action_types(&rate_completion_action_types);
        filter.hide_action_types(&reset_onboarding_action_types);
        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
    });

    cx.observe_global::<SettingsStore>(move |cx| {
        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            if is_ai_disabled {
                filter.hide_action_types(&zeta_all_action_types);
            } else if has_feature_flag {
                filter.show_action_types(&rate_completion_action_types);
            } else {
                filter.hide_action_types(&rate_completion_action_types);
            }
        });
    })
    .detach();

    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
        if !DisableAiSettings::get_global(cx).disable_ai {
            if is_enabled {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.show_action_types(&rate_completion_action_types);
                });
            } else {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.hide_action_types(&rate_completion_action_types);
                });
            }
        }
    })
    .detach();
}
2897
2898#[cfg(test)]
2899mod tests {
2900    use std::{path::Path, sync::Arc};
2901
2902    use client::UserStore;
2903    use clock::FakeSystemClock;
2904    use cloud_llm_client::{
2905        EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
2906    };
2907    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2908    use futures::{
2909        AsyncReadExt, StreamExt,
2910        channel::{mpsc, oneshot},
2911    };
2912    use gpui::{
2913        Entity, TestAppContext,
2914        http_client::{FakeHttpClient, Response},
2915        prelude::*,
2916    };
2917    use indoc::indoc;
2918    use language::OffsetRangeExt as _;
2919    use open_ai::Usage;
2920    use pretty_assertions::{assert_eq, assert_matches};
2921    use project::{FakeFs, Project};
2922    use serde_json::json;
2923    use settings::SettingsStore;
2924    use util::path;
2925    use uuid::Uuid;
2926
2927    use crate::{BufferEditPrediction, Zeta};
2928
    // End-to-end check of prediction state: a prediction in the current
    // buffer is `Local`, a prediction targeting another file is `Jump` until
    // that file's buffer is viewed, and context refresh round-trips through
    // the search tool call.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        // Respond with a tool call that searches for content in 2.txt.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // While looking at buffer1, an edit for 2.txt surfaces as a Jump.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once 2.txt's own buffer is consulted, the same prediction is Local.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
3075
    // A single prediction request: the model's unified diff is translated
    // into buffer edits at the expected position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        // "How" -> "How are you?" should produce one insertion at (1, 3).
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3139
    // Edits made to a registered buffer are reported as events: the prompt of
    // a subsequent prediction request contains the corresponding diff.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Registering the buffer makes Zeta track its edit history.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (request, respond_tx) = requests.predict.next().await.unwrap();

        // The edit above must appear as a diff in the request prompt.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3213
3214    #[gpui::test]
3215    async fn test_empty_prediction(cx: &mut TestAppContext) {
3216        let (zeta, mut requests) = init_test(cx);
3217        let fs = FakeFs::new(cx.executor());
3218        fs.insert_tree(
3219            "/root",
3220            json!({
3221                "foo.md":  "Hello!\nHow\nBye\n"
3222            }),
3223        )
3224        .await;
3225        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3226
3227        let buffer = project
3228            .update(cx, |project, cx| {
3229                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3230                project.open_buffer(path, cx)
3231            })
3232            .await
3233            .unwrap();
3234        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3235        let position = snapshot.anchor_before(language::Point::new(1, 3));
3236
3237        zeta.update(cx, |zeta, cx| {
3238            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3239        });
3240
3241        const NO_OP_DIFF: &str = indoc! { r"
3242            --- a/root/foo.md
3243            +++ b/root/foo.md
3244            @@ ... @@
3245             Hello!
3246            -How
3247            +How
3248             Bye
3249        "};
3250
3251        let (_, respond_tx) = requests.predict.next().await.unwrap();
3252        let response = model_response(NO_OP_DIFF);
3253        let id = response.id.clone();
3254        respond_tx.send(response).unwrap();
3255
3256        cx.run_until_parked();
3257
3258        zeta.read_with(cx, |zeta, cx| {
3259            assert!(
3260                zeta.current_prediction_for_buffer(&buffer, &project, cx)
3261                    .is_none()
3262            );
3263        });
3264
3265        // prediction is reported as rejected
3266        let (reject_request, _) = requests.reject.next().await.unwrap();
3267
3268        assert_eq!(
3269            &reject_request.rejections,
3270            &[EditPredictionRejection {
3271                request_id: id,
3272                reason: EditPredictionRejectReason::Empty,
3273                was_shown: false
3274            }]
3275        );
3276    }
3277
3278    #[gpui::test]
3279    async fn test_interpolated_empty(cx: &mut TestAppContext) {
3280        let (zeta, mut requests) = init_test(cx);
3281        let fs = FakeFs::new(cx.executor());
3282        fs.insert_tree(
3283            "/root",
3284            json!({
3285                "foo.md":  "Hello!\nHow\nBye\n"
3286            }),
3287        )
3288        .await;
3289        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3290
3291        let buffer = project
3292            .update(cx, |project, cx| {
3293                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3294                project.open_buffer(path, cx)
3295            })
3296            .await
3297            .unwrap();
3298        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3299        let position = snapshot.anchor_before(language::Point::new(1, 3));
3300
3301        zeta.update(cx, |zeta, cx| {
3302            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3303        });
3304
3305        let (_, respond_tx) = requests.predict.next().await.unwrap();
3306
3307        buffer.update(cx, |buffer, cx| {
3308            buffer.set_text("Hello!\nHow are you?\nBye", cx);
3309        });
3310
3311        let response = model_response(SIMPLE_DIFF);
3312        let id = response.id.clone();
3313        respond_tx.send(response).unwrap();
3314
3315        cx.run_until_parked();
3316
3317        zeta.read_with(cx, |zeta, cx| {
3318            assert!(
3319                zeta.current_prediction_for_buffer(&buffer, &project, cx)
3320                    .is_none()
3321            );
3322        });
3323
3324        // prediction is reported as rejected
3325        let (reject_request, _) = requests.reject.next().await.unwrap();
3326
3327        assert_eq!(
3328            &reject_request.rejections,
3329            &[EditPredictionRejection {
3330                request_id: id,
3331                reason: EditPredictionRejectReason::InterpolatedEmpty,
3332                was_shown: false
3333            }]
3334        );
3335    }
3336
    // Unified diff shared by several tests below: rewrites line 2 of
    // root/foo.md from "How" to "How are you?".
    const SIMPLE_DIFF: &str = indoc! { r"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ ... @@
         Hello!
        -How
        +How are you?
         Bye
    "};
3346
    // An equally-good newer prediction replaces the current one, and the
    // replaced prediction is reported to the server as rejected.
    #[gpui::test]
    async fn test_replace_current(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // Complete the first request and let the prediction land.
        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_tx.send(first_response).unwrap();

        cx.run_until_parked();

        // The first prediction is now current.
        zeta.read_with(cx, |zeta, cx| {
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // a second request is triggered
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // Respond with an identical (equally good) prediction.
        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_tx.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // second replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as replaced
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }]
        );
    }
3426
    // When a newer prediction is worse than the current one, the current
    // prediction is kept and the newer one is rejected with
    // `CurrentPreferred`.
    #[gpui::test]
    async fn test_current_preferred(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // Complete the first request; its prediction becomes current.
        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_tx.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // a second request is triggered
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        // worse than current prediction
        let second_response = model_response(indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "});
        let second_id = second_response.id.clone();
        respond_tx.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // first is preferred over second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // second is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: second_id,
                reason: EditPredictionRejectReason::CurrentPreferred,
                was_shown: false
            }]
        );
    }
3515
    // When a later request completes before an earlier one, the earlier
    // in-flight request is cancelled: its eventual response does not replace
    // the current prediction, and it is reported as `Canceled`.
    #[gpui::test]
    async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            // start two refresh tasks
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // Capture the responders for both in-flight requests.
        let (_, respond_first) = requests.predict.next().await.unwrap();
        let (_, respond_second) = requests.predict.next().await.unwrap();

        // wait for throttle
        cx.run_until_parked();

        // second responds first
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_second.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // Now let the (already-cancelled) first request respond.
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still second, since first was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            }]
        );
    }
3601
    // With two requests already pending, starting a third cancels the second
    // (the most recent pending one). The first still completes and becomes
    // current, the second's response is ignored, and the third eventually
    // replaces the first. Both the cancellation and the replacement are
    // reported in a single rejection batch.
    #[gpui::test]
    async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            // start two refresh tasks
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // wait for throttle, so requests are sent
        cx.run_until_parked();

        let (_, respond_first) = requests.predict.next().await.unwrap();
        let (_, respond_second) = requests.predict.next().await.unwrap();

        zeta.update(cx, |zeta, cx| {
            // start a third request
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

            // 2 are pending, so 2nd is cancelled
            assert_eq!(
                zeta.get_or_init_zeta_project(&project, cx)
                    .cancelled_predictions
                    .iter()
                    .copied()
                    .collect::<Vec<_>>(),
                [1]
            );
        });

        // wait for throttle
        cx.run_until_parked();

        let (_, respond_third) = requests.predict.next().await.unwrap();

        // Complete the first request; it becomes the current prediction.
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // Respond to the cancelled second request; its result must be ignored.
        let cancelled_response = model_response(SIMPLE_DIFF);
        let cancelled_id = cancelled_response.id.clone();
        respond_second.send(cancelled_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still first, since second was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        let third_response = model_response(SIMPLE_DIFF);
        let third_response_id = third_response.id.clone();
        respond_third.send(third_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // third completes and replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                third_response_id
            );
        });

        // second is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[
                EditPredictionRejection {
                    request_id: cancelled_id,
                    reason: EditPredictionRejectReason::Canceled,
                    was_shown: false
                },
                EditPredictionRejection {
                    request_id: first_id,
                    reason: EditPredictionRejectReason::Replaced,
                    was_shown: false
                }
            ]
        );
    }
3729
3730    // Skipped until we start including diagnostics in prompt
3731    // #[gpui::test]
3732    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3733    //     let (zeta, mut req_rx) = init_test(cx);
3734    //     let fs = FakeFs::new(cx.executor());
3735    //     fs.insert_tree(
3736    //         "/root",
3737    //         json!({
3738    //             "foo.md": "Hello!\nBye"
3739    //         }),
3740    //     )
3741    //     .await;
3742    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3743
3744    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3745    //     let diagnostic = lsp::Diagnostic {
3746    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3747    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
3748    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3749    //         ..Default::default()
3750    //     };
3751
3752    //     project.update(cx, |project, cx| {
3753    //         project.lsp_store().update(cx, |lsp_store, cx| {
3754    //             // Create some diagnostics
3755    //             lsp_store
3756    //                 .update_diagnostics(
3757    //                     LanguageServerId(0),
3758    //                     lsp::PublishDiagnosticsParams {
3759    //                         uri: path_to_buffer_uri.clone(),
3760    //                         diagnostics: vec![diagnostic],
3761    //                         version: None,
3762    //                     },
3763    //                     None,
3764    //                     language::DiagnosticSourceKind::Pushed,
3765    //                     &[],
3766    //                     cx,
3767    //                 )
3768    //                 .unwrap();
3769    //         });
3770    //     });
3771
3772    //     let buffer = project
3773    //         .update(cx, |project, cx| {
3774    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3775    //             project.open_buffer(path, cx)
3776    //         })
3777    //         .await
3778    //         .unwrap();
3779
3780    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3781    //     let position = snapshot.anchor_before(language::Point::new(0, 0));
3782
3783    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
3784    //         zeta.request_prediction(&project, &buffer, position, cx)
3785    //     });
3786
3787    //     let (request, _respond_tx) = req_rx.next().await.unwrap();
3788
3789    //     assert_eq!(request.diagnostic_groups.len(), 1);
3790    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3791    //         .unwrap();
3792    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3793    //     assert_eq!(
3794    //         value,
3795    //         json!({
3796    //             "entries": [{
3797    //                 "range": {
3798    //                     "start": 8,
3799    //                     "end": 10
3800    //                 },
3801    //                 "diagnostic": {
3802    //                     "source": null,
3803    //                     "code": null,
3804    //                     "code_description": null,
3805    //                     "severity": 1,
3806    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3807    //                     "markdown": null,
3808    //                     "group_id": 0,
3809    //                     "is_primary": true,
3810    //                     "is_disk_based": false,
3811    //                     "is_unnecessary": false,
3812    //                     "source_kind": "Pushed",
3813    //                     "data": null,
3814    //                     "underline": true
3815    //                 }
3816    //             }],
3817    //             "primary_ix": 0
3818    //         })
3819    //     );
3820    // }
3821
3822    fn model_response(text: &str) -> open_ai::Response {
3823        open_ai::Response {
3824            id: Uuid::new_v4().to_string(),
3825            object: "response".into(),
3826            created: 0,
3827            model: "model".into(),
3828            choices: vec![open_ai::Choice {
3829                index: 0,
3830                message: open_ai::RequestMessage::Assistant {
3831                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
3832                    tool_calls: vec![],
3833                },
3834                finish_reason: None,
3835            }],
3836            usage: Usage {
3837                prompt_tokens: 0,
3838                completion_tokens: 0,
3839                total_tokens: 0,
3840            },
3841        }
3842    }
3843
3844    fn prompt_from_request(request: &open_ai::Request) -> &str {
3845        assert_eq!(request.messages.len(), 1);
3846        let open_ai::RequestMessage::User {
3847            content: open_ai::MessageContent::Plain(content),
3848            ..
3849        } = &request.messages[0]
3850        else {
3851            panic!(
3852                "Request does not have single user message of type Plain. {:#?}",
3853                request
3854            );
3855        };
3856        content
3857    }
3858
    /// Receiving ends of the fake-server channels created by `init_test`:
    /// each intercepted HTTP request is forwarded here along with a oneshot
    /// sender the test uses to supply the response.
    struct RequestChannels {
        // Requests to `/predict_edits/raw`, paired with their response sender.
        predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
        // Requests to `/predict_edits/reject`, paired with their response sender.
        reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
    }
3863
    /// Sets up a `Zeta` instance backed by a fake HTTP client.
    ///
    /// The fake server answers `/client/llm_tokens` immediately with a static
    /// token, and forwards `/predict_edits/raw` and `/predict_edits/reject`
    /// requests through the returned [`RequestChannels`], letting each test
    /// inspect the request and decide when (and with what) to respond.
    fn init_test(cx: &mut TestAppContext) -> (Entity<Zeta>, RequestChannels) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            // Channels over which intercepted requests are handed to the test.
            let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
            let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let predict_req_tx = predict_req_tx.clone();
                    let reject_req_tx = reject_req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh: answer immediately with a test token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Forward the deserialized request to the test and
                            // block until it supplies the response.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            // Same forwarding scheme for rejection reports.
                            "/predict_edits/reject" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (
                zeta,
                RequestChannels {
                    predict: predict_req_rx,
                    reject: reject_req_rx,
                },
            )
        })
    }
3930}