zeta.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
   7    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
   8    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBody,
   9    ZED_VERSION_HEADER_NAME,
  10};
  11use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  12use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  13use collections::{HashMap, HashSet};
  14use command_palette_hooks::CommandPaletteFilter;
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{
  17    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  18    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  19    SyntaxIndex, SyntaxIndexState,
  20};
  21use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
  22use futures::channel::{mpsc, oneshot};
  23use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _};
  24use gpui::{
  25    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  26    http_client::{self, AsyncBody, Method},
  27    prelude::*,
  28};
  29use language::{
  30    Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
  31};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, RefreshLlmTokenListener};
  34use open_ai::FunctionDefinition;
  35use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  36use release_channel::AppVersion;
  37use semver::Version;
  38use serde::de::DeserializeOwned;
  39use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
  40use std::any::{Any as _, TypeId};
  41use std::collections::{VecDeque, hash_map};
  42use telemetry_events::EditPredictionRating;
  43use workspace::Workspace;
  44
  45use std::ops::Range;
  46use std::path::Path;
  47use std::rc::Rc;
  48use std::str::FromStr as _;
  49use std::sync::{Arc, LazyLock};
  50use std::time::{Duration, Instant};
  51use std::{env, mem};
  52use thiserror::Error;
  53use util::rel_path::RelPathBuf;
  54use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
  55use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  56
  57pub mod assemble_excerpts;
  58mod license_detection;
  59mod onboarding_modal;
  60mod prediction;
  61mod provider;
  62mod rate_prediction_modal;
  63pub mod retrieval_search;
  64pub mod sweep_ai;
  65pub mod udiff;
  66mod xml_edits;
  67pub mod zeta1;
  68
  69#[cfg(test)]
  70mod zeta_tests;
  71
  72use crate::assemble_excerpts::assemble_excerpts;
  73use crate::license_detection::LicenseDetectionWatcher;
  74use crate::onboarding_modal::ZedPredictModal;
  75pub use crate::prediction::EditPrediction;
  76pub use crate::prediction::EditPredictionId;
  77pub use crate::prediction::EditPredictionInputs;
  78use crate::prediction::EditPredictionResult;
  79use crate::rate_prediction_modal::{
  80    NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
  81    ThumbsUpActivePrediction,
  82};
  83pub use crate::sweep_ai::SweepAi;
  84use crate::zeta1::request_prediction_with_zeta1;
  85pub use provider::ZetaEditPredictionProvider;
  86
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  98
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Edits within this many lines of the previous edit are coalesced into a single event.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key-value store key under which the user's data collection choice is persisted.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 103
 104pub struct SweepFeatureFlag;
 105
 106impl FeatureFlag for SweepFeatureFlag {
 107    const NAME: &str = "sweep-ai";
 108}
/// Default sizing for the excerpt taken around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim to place roughly half of the excerpt bytes before the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};
 114
/// Context gathering defaults to the agentic mode.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Default options for agentic context retrieval.
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};
 121
/// Default options for syntax-index-based context retrieval.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        // Declaration retrieval is disabled by default.
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };
 131
/// Default top-level Zeta options.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
 140
/// When `ZED_ZETA2_OLLAMA` is set to a non-empty value, requests go to a local Ollama server.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
/// Model id used for the context-retrieval step; overridable via `ZED_ZETA2_CONTEXT_MODEL`.
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
/// Model id used for edit prediction; overridable via `ZED_ZETA2_MODEL`.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
/// Prediction endpoint override (`ZED_PREDICT_EDITS_URL`); falls back to the
/// local Ollama endpoint when Ollama mode is on, otherwise `None` (use the
/// default Zed LLM endpoint).
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
 168
/// Feature flag gating the Zeta2 prediction model; always enabled for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}
 178
/// Global wrapper holding the singleton [`Zeta`] entity in the app context.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 183
/// Singleton coordinating edit predictions across all projects.
///
/// Holds per-project state, the LLM auth token, the selected prediction
/// backend, and bookkeeping for rejected/shown/rated predictions.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Keeps `llm_token` refreshed when the listener emits refresh events.
    _llm_token_subscription: Subscription,
    // Keyed by the `Project` entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // Presumably set when the server signals this Zed version is too old
    // (see MINIMUM_REQUIRED_VERSION_HEADER_NAME) — TODO confirm in handler code.
    update_required: bool,
    // When present, debug events are streamed to the receiver from `debug_info()`.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    pub sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    // Rejections buffered locally until flushed to the server in a batch.
    rejected_predictions: Vec<EditPredictionRejection>,
    // Signals the background task spawned in `new` to flush buffered rejections.
    reject_predictions_tx: mpsc::UnboundedSender<()>,
    reject_predictions_debounce_task: Option<Task<()>>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
 204
/// Which backend produces edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}
 212
/// Tunable parameters controlling context gathering and prompt assembly.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    // Upper bound on total prompt size, in bytes.
    pub max_prompt_bytes: usize,
    // Upper bound on diagnostic text included in the prompt, in bytes.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    // Parallelism used when building the syntax index.
    pub file_indexing_parallelism: usize,
    // Edits closer together in time than this may be coalesced into one event.
    pub buffer_change_grouping_interval: Duration,
}
 222
/// How context for a prediction request is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Context is retrieved by a model issuing search queries.
    Agentic(AgenticContextOptions),
    /// Context is retrieved from the local syntax index.
    Syntax(EditPredictionContextOptions),
}

/// Options for agentic context retrieval.
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
 233
 234impl ContextMode {
 235    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 236        match self {
 237            ContextMode::Agentic(options) => &options.excerpt,
 238            ContextMode::Syntax(options) => &options.excerpt,
 239        }
 240    }
 241}
 242
/// Debug events streamed to subscribers of [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}
 251
/// Emitted when context retrieval begins for a project.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    // The prompt sent to the search model.
    pub search_prompt: String,
}
 258
/// Timing info for a context-retrieval phase of a project.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}
 264
/// Emitted when an edit prediction request is made.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    // How long context retrieval took before the request was issued.
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    // The locally assembled prompt, or an error message.
    pub local_prompt: Result<String, String>,
    // Resolves with the model response (or error message) and the elapsed time.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}
 274
/// Emitted when the search model produces its queries.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}
 281
/// Debug information attached to a prediction request/response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 283
/// Per-project prediction state, keyed by project entity id in [`Zeta::projects`].
struct ZetaProject {
    // Present only when `ContextMode::Syntax` is configured.
    syntax_index: Option<Entity<SyntaxIndex>>,
    // Finalized buffer-change events, oldest first; bounded by `EVENT_COUNT_MAX`.
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    // Most recent change, kept un-finalized so nearby edits can coalesce into it.
    last_event: Option<LastEvent>,
    // Most-recently-active project paths, front = most recent.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // At most two in-flight prediction requests.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions cancelled before their task completed.
    cancelled_predictions: HashSet<usize>,
    // Retrieved context ranges per buffer; `None` until the first retrieval completes.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
 302
impl ZetaProject {
    /// All recorded events, including the still-coalescing `last_event`
    /// (finalized on the fly; it is skipped if it produced no observable change).
    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
        self.events
            .iter()
            .cloned()
            .chain(
                self.last_event
                    .as_ref()
                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
            )
            .collect()
    }

    /// Marks a pending prediction as cancelled, and — if its request still
    /// completes with a prediction — reports that prediction as rejected.
    fn cancel_pending_prediction(
        &mut self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<Zeta>,
    ) {
        self.cancelled_predictions.insert(pending_prediction.id);

        cx.spawn(async move |this, cx| {
            // The task resolves to `None` when the request produced no prediction.
            let Some(prediction_id) = pending_prediction.task.await else {
                return;
            };

            this.update(cx, |this, cx| {
                this.reject_prediction(
                    prediction_id,
                    EditPredictionRejectReason::Canceled,
                    false,
                    cx,
                );
            })
            .ok();
        })
        .detach()
    }
}
 341
/// The prediction currently offered to the user for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // Whether the prediction has been displayed to the user yet.
    pub was_shown: bool,
}
 348
impl CurrentEditPrediction {
    /// Whether `self` (a freshly computed prediction) should replace
    /// `old_prediction` as the one shown to the user.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // A new prediction that no longer applies to the current buffer
        // content is never an improvement.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // The old prediction no longer interpolates, so any valid one wins.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // For single-edit predictions in the same buffer, only replace
            // when the new edit extends the old one at the same range.
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
 387
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    /// An edit in the buffer with this entity id.
    Buffer(EntityId),
}
 393
 394impl PredictionRequestedBy {
 395    pub fn buffer_id(&self) -> Option<EntityId> {
 396        match self {
 397            PredictionRequestedBy::DiagnosticsUpdate => None,
 398            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 399        }
 400    }
 401}
 402
/// An in-flight prediction request; the task yields the resulting prediction id, if any.
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}
 408
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer; surface it as a jump.
    Jump { prediction: &'a EditPrediction },
}
 415
 416#[cfg(test)]
 417impl std::ops::Deref for BufferEditPrediction<'_> {
 418    type Target = EditPrediction;
 419
 420    fn deref(&self) -> &Self::Target {
 421        match self {
 422            BufferEditPrediction::Local { prediction } => prediction,
 423            BufferEditPrediction::Jump { prediction } => prediction,
 424        }
 425    }
 426}
 427
/// A buffer whose edits feed the project's event history.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    // Edit subscription plus release observer; dropped on unregistration.
    _subscriptions: [gpui::Subscription; 2],
}
 432
/// The most recent (possibly still-coalescing) buffer change.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    // Anchor at the end of the last edit, used to decide whether the next
    // edit is close enough to coalesce.
    end_edit_anchor: Option<Anchor>,
}
 438
impl LastEvent {
    /// Converts this coalesced change into a `BufferChange` event, or `None`
    /// when nothing observable changed (same path and an empty diff).
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<Arc<predict_edits_v3::Event>> {
        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);

        let file = self.new_snapshot.file();
        let old_file = self.old_snapshot.file();

        // Only flag as open source when both sides of the change live in
        // worktrees whose license watcher says the project is open source.
        let in_open_source_repo = [file, old_file].iter().all(|file| {
            file.is_some_and(|file| {
                license_detection_watchers
                    .get(&file.worktree_id(cx))
                    .is_some_and(|watcher| watcher.is_project_open_source())
            })
        });

        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(Arc::new(predict_edits_v3::Event::BufferChange {
                old_path,
                path,
                diff,
                in_open_source_repo,
                // TODO: Actually detect if this edit was predicted or not
                predicted: false,
            }))
        }
    }
}
 475
 476fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
 477    if let Some(file) = snapshot.file() {
 478        file.full_path(cx).into()
 479    } else {
 480        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 481    }
 482}
 483
 484impl Zeta {
 485    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 486        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 487    }
 488
 489    pub fn global(
 490        client: &Arc<Client>,
 491        user_store: &Entity<UserStore>,
 492        cx: &mut App,
 493    ) -> Entity<Self> {
 494        cx.try_global::<ZetaGlobal>()
 495            .map(|global| global.0.clone())
 496            .unwrap_or_else(|| {
 497                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 498                cx.set_global(ZetaGlobal(zeta.clone()));
 499                zeta
 500            })
 501    }
 502
    /// Creates a `Zeta`, wiring up LLM token refresh and a background task
    /// that batches rejected-prediction reports to the server.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        // Each message on this channel triggers a flush of buffered rejections.
        let (reject_tx, mut reject_rx) = mpsc::unbounded();
        cx.spawn(async move |this, cx| {
            while let Some(()) = reject_rx.next().await {
                this.update(cx, |this, cx| this.flush_rejected_predictions(cx))?
                    .await
                    .log_err();
            }
            anyhow::Ok(())
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Refresh the LLM token whenever the listener signals it should be.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            rejected_predictions: Vec::new(),
            reject_predictions_debounce_task: None,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
 550
    /// Selects which backend produces edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }
 554
 555    pub fn has_sweep_api_token(&self) -> bool {
 556        self.sweep_ai
 557            .api_token
 558            .clone()
 559            .now_or_never()
 560            .flatten()
 561            .is_some()
 562    }
 563
    /// Sets the cache used when running evals.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 568
 569    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 570        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 571        self.debug_tx = Some(debug_watch_tx);
 572        debug_watch_rx
 573    }
 574
    /// The current Zeta options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 578
    /// Replaces the Zeta options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 582
 583    pub fn clear_history(&mut self) {
 584        for zeta_project in self.projects.values_mut() {
 585            zeta_project.events.clear();
 586        }
 587    }
 588
 589    pub fn context_for_project(
 590        &self,
 591        project: &Entity<Project>,
 592    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 593        self.projects
 594            .get(&project.entity_id())
 595            .and_then(|project| {
 596                Some(
 597                    project
 598                        .context
 599                        .as_ref()?
 600                        .iter()
 601                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 602                )
 603            })
 604            .into_iter()
 605            .flatten()
 606    }
 607
 608    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 609        if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
 610            self.user_store.read(cx).edit_prediction_usage()
 611        } else {
 612            None
 613        }
 614    }
 615
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }
 619
    /// Starts tracking `buffer` in `project` so its edits feed the event history.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 629
    /// Returns the `ZetaProject` for `project`, creating it (and subscribing
    /// to project events) on first use.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                // The syntax index is only needed in syntax context mode.
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
 662
    /// Reacts to project events: maintains the MRU list of recently active
    /// paths, and refreshes predictions when diagnostics change (Zeta2 only).
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front of the MRU list, deduplicating.
                    // NOTE(review): no length cap is applied here — confirm
                    // growth of `recent_paths` is bounded elsewhere.
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }
 695
    /// Registers `buffer` for change tracking: installs a license watcher for
    /// its worktree (created lazily, removed when the worktree is released)
    /// and subscriptions that feed edits into the event history and clean up
    /// when the buffer is dropped.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher when its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Feed edits into the project's event history.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Unregister when the buffer is dropped.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 755
    /// Records an edit to `buffer` in the project's event history, coalescing
    /// consecutive edits in the same buffer that land within
    /// `CHANGE_GROUPING_LINE_SPAN` lines of the previous edit.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Nothing changed since the last snapshot we recorded.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Anchor at the end of the last edit in this batch of changes.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // Coalesce only when this edit directly follows the pending
            // event's snapshot in the same buffer...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and lands within a few lines of the previous edit.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // Extend the pending event instead of starting a new one.
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Keep the history bounded; the pending event below counts toward the cap.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        // Finalize the previous pending event before starting a new one.
        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
 818
    /// The current prediction as seen from `buffer`: `Local` when the
    /// prediction edits this buffer, `Jump` when it targets another buffer
    /// but should still be surfaced here.
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            // Only offer a jump from the buffer that requested the prediction,
            // or from anywhere when a diagnostics update triggered it.
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
 850
    /// Reports acceptance of the current prediction for `project` to the
    /// server, consuming it and cancelling any still-pending requests.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        // Only zeta models report acceptance; sweep has no accept endpoint here.
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Consume the prediction so it cannot be accepted or rejected twice.
        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Accepting supersedes anything still in flight.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // ZED_ACCEPT_PREDICTION_URL overrides the endpoint (e.g. for local testing).
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // Route the response through the shared handler (token refresh,
            // version checks, etc.).
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 903
    /// Sends all accumulated prediction rejections to the server in one
    /// batch request.
    ///
    /// On success, drains every queued rejection up to and including the
    /// last one present when the request body was built; rejections recorded
    /// while the request is in flight are kept for the next flush.
    fn flush_rejected_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        // Only zeta models report rejections; sweep has no reject endpoint here.
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        // Remember the newest rejection so we only drain what this request
        // actually sent.
        let last_rejection = self.rejected_predictions.last().cloned();
        let Some(last_rejection) = last_rejection else {
            // Nothing queued; nothing to flush.
            return Task::ready(anyhow::Ok(()));
        };

        // NOTE(review): `.ok()` silently swallows a serialization failure,
        // sending the request with a `None` body — confirm this is intended.
        let body = serde_json::to_string(&RejectEditPredictionsBody {
            rejections: self.rejected_predictions.clone(),
        })
        .ok();

        cx.spawn(async move |this, cx| {
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])?;

            cx.background_spawn(Self::send_api_request::<()>(
                move |builder| {
                    let req = builder.uri(url.as_ref()).body(body.clone().into());
                    Ok(req?)
                },
                client,
                llm_token,
                app_version,
            ))
            .await
            .context("Failed to reject edit predictions")?;

            // Drain only the prefix covered by this request; later entries
            // stay queued for the next flush.
            this.update(cx, |this, _| {
                if let Some(ix) = this
                    .rejected_predictions
                    .iter()
                    .position(|rejection| rejection.request_id == last_rejection.request_id)
                {
                    this.rejected_predictions.drain(..ix + 1);
                }
            })
        })
    }
 951
 952    fn reject_current_prediction(
 953        &mut self,
 954        reason: EditPredictionRejectReason,
 955        project: &Entity<Project>,
 956        cx: &mut Context<Self>,
 957    ) {
 958        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 959            project_state.pending_predictions.clear();
 960            if let Some(prediction) = project_state.current_prediction.take() {
 961                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
 962            }
 963        };
 964    }
 965
 966    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
 967        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 968            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
 969                if !current_prediction.was_shown {
 970                    current_prediction.was_shown = true;
 971                    self.shown_predictions
 972                        .push_front(current_prediction.prediction.clone());
 973                    if self.shown_predictions.len() > 50 {
 974                        let completion = self.shown_predictions.pop_back().unwrap();
 975                        self.rated_predictions.remove(&completion.id);
 976                    }
 977                }
 978            }
 979        }
 980    }
 981
    /// Records a rejected prediction and schedules a flush to the server.
    ///
    /// The flush is debounced by 15 seconds. Once the local queue reaches
    /// half the per-request cap, the flush is triggered immediately instead.
    /// Each call replaces the previous debounce task, restarting the timer.
    fn reject_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        reason: EditPredictionRejectReason,
        was_shown: bool,
        cx: &mut Context<Self>,
    ) {
        self.rejected_predictions.push(EditPredictionRejection {
            request_id: prediction_id.to_string(),
            reason,
            was_shown,
        });

        // Flush early at half the cap, leaving headroom for rejections that
        // arrive while a flush request is in flight.
        let reached_request_limit =
            self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2;
        let reject_tx = self.reject_predictions_tx.clone();
        self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
            const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
            if !reached_request_limit {
                cx.background_executor()
                    .timer(REJECT_REQUEST_DEBOUNCE)
                    .await;
            }
            // Signal the flush loop; failure just means the receiver is gone.
            reject_tx.unbounded_send(()).log_err();
        }));
    }
1008
1009    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1010        self.projects
1011            .get(&project.entity_id())
1012            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1013    }
1014
    /// Queues a throttled prediction refresh anchored at `position` in
    /// `buffer`, attributing the resulting prediction to that buffer.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Throttled per buffer: `queue_prediction_refresh` rate-limits
        // consecutive refreshes keyed on this buffer's entity id.
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &buffer,
                        position,
                        PredictEditsRequestTrigger::Other,
                        cx,
                    )
                })
                .log_err()
            else {
                // The entity was dropped; nothing to request.
                return Task::ready(anyhow::Ok(None));
            };

            // Tag the result with the originating buffer so the UI can decide
            // whether to show a jump affordance in other buffers.
            cx.spawn(async move |_cx| {
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }
1050
    /// Queues a prediction refresh driven by a diagnostics update, targeting
    /// the location of a diagnostic near the project's active entry.
    ///
    /// Does nothing if a prediction already exists (buffer-originated
    /// predictions take precedence). Even if a prediction is produced, it is
    /// rejected as `CurrentPreferred` when a current prediction appeared in
    /// the meantime.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        // Throttled per project (keyed on the project's entity id).
        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                // No active entry (or project dropped); nothing to refresh.
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Default (empty) search range / cursor point: no diagnostics
                // are excluded as "already covered" and distance is measured
                // from the buffer start.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // A buffer-originated prediction may have landed while we were
                // waiting; if so, reject ours in its favor.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_zeta_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1131
    /// Minimum spacing between consecutive prediction refreshes for the same
    /// throttle entity (see `queue_prediction_refresh`). Zero under test so
    /// tests don't have to wait out the throttle.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1136
    /// Enqueues a prediction refresh produced by `do_refresh`, with
    /// throttling, cancellation, and replacement of the current prediction.
    ///
    /// Consecutive refreshes for the same `throttle_entity` are spaced at
    /// least [`Self::THROTTLE_TIMEOUT`] apart. At most two refreshes are
    /// pending at once: enqueueing a third replaces (and cancels) the queued
    /// second one. When a refresh completes, its result may replace the
    /// current prediction (per `should_replace_prediction`) or be rejected
    /// as `CurrentPreferred`, and any pending refreshes enqueued before it
    /// are cancelled.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        let zeta_project = self.get_or_init_zeta_project(&project, cx);
        let pending_prediction_id = zeta_project.next_pending_prediction_id;
        zeta_project.next_pending_prediction_id += 1;
        let last_request = zeta_project.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Wait out the remainder of the throttle window if the previous
            // refresh came from the same entity.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_zeta_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let zeta_project = this.get_or_init_zeta_project(&project, cx);

                // Cancellation may also have happened while the request ran.
                let is_cancelled = zeta_project
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                            };

                            // Either replace the existing prediction (rejecting
                            // it as `Replaced`) or reject the new one as
                            // `CurrentPreferred`.
                            if let Some(current_prediction) =
                                zeta_project.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                        cx,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                        cx,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            // The request itself produced a rejection.
                            this.reject_prediction(prediction_result.id, reject_reason, false, cx);
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-borrow: the reject_* calls above required `&mut self`.
                let zeta_project = this.get_or_init_zeta_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    zeta_project.current_prediction = Some(new_prediction);
                }

                // Drop this entry from the pending list and cancel everything
                // enqueued before it (their results are now stale).
                let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            zeta_project.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_zeta_project(&project, cx)
                    .pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Keep at most two pending refreshes: one in flight plus one queued.
        // A third replaces the queued one, which is cancelled.
        if zeta_project.pending_predictions.len() <= 1 {
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if zeta_project.pending_predictions.len() == 2 {
            let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            zeta_project.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1279
1280    pub fn request_prediction(
1281        &mut self,
1282        project: &Entity<Project>,
1283        active_buffer: &Entity<Buffer>,
1284        position: language::Anchor,
1285        trigger: PredictEditsRequestTrigger,
1286        cx: &mut Context<Self>,
1287    ) -> Task<Result<Option<EditPredictionResult>>> {
1288        self.request_prediction_internal(
1289            project.clone(),
1290            active_buffer.clone(),
1291            position,
1292            trigger,
1293            cx.has_flag::<Zeta2FeatureFlag>(),
1294            cx,
1295        )
1296    }
1297
    /// Dispatches a prediction request to the configured model (zeta1,
    /// zeta2, or sweep).
    ///
    /// If the request yields no prediction, edit events exist, and
    /// `allow_jump` is set, retries once at the nearest diagnostic outside
    /// the ±20-line window around the cursor (see
    /// `next_diagnostic_location`), with jumping disabled on the retry to
    /// prevent recursion.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Half-height (in rows) of the diagnostic window around the cursor.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            // No prediction at the cursor: optionally retry at a diagnostic
            // location outside the already-searched window.
            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    // `allow_jump = false` here: jump at most once.
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1389
    /// Finds a diagnostic location to jump to for a follow-up prediction.
    ///
    /// First searches the active buffer for the diagnostic group (primary
    /// entry) closest to the cursor by row distance, skipping groups that
    /// overlap `active_buffer_diagnostic_search_range` (those were already
    /// covered by the previous request). If none is found, falls back to
    /// opening the project buffer with diagnostics whose path shares the
    /// most leading components with the active buffer's path, and returns
    /// the start of its top-ranked diagnostic.
    ///
    /// Returns `Ok(None)` when no candidate exists anywhere.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    // Already covered by the previous request's window.
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Fall back to another buffer with diagnostics.
        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // `min_by_key` on severity picks the entry with the smallest
                // severity value — NOTE(review): presumably lower value means
                // more severe (as in LSP); confirm the ordering of this type.
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1468
1469    fn request_prediction_with_zeta2(
1470        &mut self,
1471        project: &Entity<Project>,
1472        active_buffer: &Entity<Buffer>,
1473        active_snapshot: BufferSnapshot,
1474        position: language::Anchor,
1475        events: Vec<Arc<Event>>,
1476        trigger: PredictEditsRequestTrigger,
1477        cx: &mut Context<Self>,
1478    ) -> Task<Result<Option<EditPredictionResult>>> {
1479        let project_state = self.projects.get(&project.entity_id());
1480
1481        let index_state = project_state.and_then(|state| {
1482            state
1483                .syntax_index
1484                .as_ref()
1485                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
1486        });
1487        let options = self.options.clone();
1488        let buffer_snapshotted_at = Instant::now();
1489        let Some(excerpt_path) = active_snapshot
1490            .file()
1491            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
1492        else {
1493            return Task::ready(Err(anyhow!("No file path for excerpt")));
1494        };
1495        let client = self.client.clone();
1496        let llm_token = self.llm_token.clone();
1497        let app_version = AppVersion::global(cx);
1498        let worktree_snapshots = project
1499            .read(cx)
1500            .worktrees(cx)
1501            .map(|worktree| worktree.read(cx).snapshot())
1502            .collect::<Vec<_>>();
1503        let debug_tx = self.debug_tx.clone();
1504
1505        let diagnostics = active_snapshot.diagnostic_sets().clone();
1506
1507        let file = active_buffer.read(cx).file();
1508        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
1509            let mut path = f.worktree.read(cx).absolutize(&f.path);
1510            if path.pop() { Some(path) } else { None }
1511        });
1512
1513        // TODO data collection
1514        let can_collect_data = file
1515            .as_ref()
1516            .map_or(false, |file| self.can_collect_file(project, file, cx));
1517
1518        let empty_context_files = HashMap::default();
1519        let context_files = project_state
1520            .and_then(|project_state| project_state.context.as_ref())
1521            .unwrap_or(&empty_context_files);
1522
1523        #[cfg(feature = "eval-support")]
1524        let parsed_fut = futures::future::join_all(
1525            context_files
1526                .keys()
1527                .map(|buffer| buffer.read(cx).parsing_idle()),
1528        );
1529
1530        let mut included_files = context_files
1531            .iter()
1532            .filter_map(|(buffer_entity, ranges)| {
1533                let buffer = buffer_entity.read(cx);
1534                Some((
1535                    buffer_entity.clone(),
1536                    buffer.snapshot(),
1537                    buffer.file()?.full_path(cx).into(),
1538                    ranges.clone(),
1539                ))
1540            })
1541            .collect::<Vec<_>>();
1542
1543        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
1544            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
1545        });
1546
1547        #[cfg(feature = "eval-support")]
1548        let eval_cache = self.eval_cache.clone();
1549
1550        let request_task = cx.background_spawn({
1551            let active_buffer = active_buffer.clone();
1552            async move {
1553                #[cfg(feature = "eval-support")]
1554                parsed_fut.await;
1555
1556                let index_state = if let Some(index_state) = index_state {
1557                    Some(index_state.lock_owned().await)
1558                } else {
1559                    None
1560                };
1561
1562                let cursor_offset = position.to_offset(&active_snapshot);
1563                let cursor_point = cursor_offset.to_point(&active_snapshot);
1564
1565                let before_retrieval = Instant::now();
1566
1567                let (diagnostic_groups, diagnostic_groups_truncated) =
1568                    Self::gather_nearby_diagnostics(
1569                        cursor_offset,
1570                        &diagnostics,
1571                        &active_snapshot,
1572                        options.max_diagnostic_bytes,
1573                    );
1574
1575                let cloud_request = match options.context {
1576                    ContextMode::Agentic(context_options) => {
1577                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
1578                            cursor_point,
1579                            &active_snapshot,
1580                            &context_options.excerpt,
1581                            index_state.as_deref(),
1582                        ) else {
1583                            return Ok((None, None));
1584                        };
1585
1586                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
1587                            ..active_snapshot.anchor_before(excerpt.range.end);
1588
1589                        if let Some(buffer_ix) =
1590                            included_files.iter().position(|(_, snapshot, _, _)| {
1591                                snapshot.remote_id() == active_snapshot.remote_id()
1592                            })
1593                        {
1594                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
1595                            ranges.push(excerpt_anchor_range);
1596                            retrieval_search::merge_anchor_ranges(ranges, buffer);
1597                            let last_ix = included_files.len() - 1;
1598                            included_files.swap(buffer_ix, last_ix);
1599                        } else {
1600                            included_files.push((
1601                                active_buffer.clone(),
1602                                active_snapshot.clone(),
1603                                excerpt_path.clone(),
1604                                vec![excerpt_anchor_range],
1605                            ));
1606                        }
1607
1608                        let included_files = included_files
1609                            .iter()
1610                            .map(|(_, snapshot, path, ranges)| {
1611                                let ranges = ranges
1612                                    .iter()
1613                                    .map(|range| {
1614                                        let point_range = range.to_point(&snapshot);
1615                                        Line(point_range.start.row)..Line(point_range.end.row)
1616                                    })
1617                                    .collect::<Vec<_>>();
1618                                let excerpts = assemble_excerpts(&snapshot, ranges);
1619                                predict_edits_v3::IncludedFile {
1620                                    path: path.clone(),
1621                                    max_row: Line(snapshot.max_point().row),
1622                                    excerpts,
1623                                }
1624                            })
1625                            .collect::<Vec<_>>();
1626
1627                        predict_edits_v3::PredictEditsRequest {
1628                            excerpt_path,
1629                            excerpt: String::new(),
1630                            excerpt_line_range: Line(0)..Line(0),
1631                            excerpt_range: 0..0,
1632                            cursor_point: predict_edits_v3::Point {
1633                                line: predict_edits_v3::Line(cursor_point.row),
1634                                column: cursor_point.column,
1635                            },
1636                            included_files,
1637                            referenced_declarations: vec![],
1638                            events,
1639                            can_collect_data,
1640                            diagnostic_groups,
1641                            diagnostic_groups_truncated,
1642                            debug_info: debug_tx.is_some(),
1643                            prompt_max_bytes: Some(options.max_prompt_bytes),
1644                            prompt_format: options.prompt_format,
1645                            // TODO [zeta2]
1646                            signatures: vec![],
1647                            excerpt_parent: None,
1648                            git_info: None,
1649                            trigger,
1650                        }
1651                    }
1652                    ContextMode::Syntax(context_options) => {
1653                        let Some(context) = EditPredictionContext::gather_context(
1654                            cursor_point,
1655                            &active_snapshot,
1656                            parent_abs_path.as_deref(),
1657                            &context_options,
1658                            index_state.as_deref(),
1659                        ) else {
1660                            return Ok((None, None));
1661                        };
1662
1663                        make_syntax_context_cloud_request(
1664                            excerpt_path,
1665                            context,
1666                            events,
1667                            can_collect_data,
1668                            diagnostic_groups,
1669                            diagnostic_groups_truncated,
1670                            None,
1671                            debug_tx.is_some(),
1672                            &worktree_snapshots,
1673                            index_state.as_deref(),
1674                            Some(options.max_prompt_bytes),
1675                            options.prompt_format,
1676                            trigger,
1677                        )
1678                    }
1679                };
1680
1681                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
1682
1683                let inputs = EditPredictionInputs {
1684                    included_files: cloud_request.included_files,
1685                    events: cloud_request.events,
1686                    cursor_point: cloud_request.cursor_point,
1687                    cursor_path: cloud_request.excerpt_path,
1688                };
1689
1690                let retrieval_time = Instant::now() - before_retrieval;
1691
1692                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
1693                    let (response_tx, response_rx) = oneshot::channel();
1694
1695                    debug_tx
1696                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
1697                            ZetaEditPredictionDebugInfo {
1698                                inputs: inputs.clone(),
1699                                retrieval_time,
1700                                buffer: active_buffer.downgrade(),
1701                                local_prompt: match prompt_result.as_ref() {
1702                                    Ok((prompt, _)) => Ok(prompt.clone()),
1703                                    Err(err) => Err(err.to_string()),
1704                                },
1705                                position,
1706                                response_rx,
1707                            },
1708                        ))
1709                        .ok();
1710                    Some(response_tx)
1711                } else {
1712                    None
1713                };
1714
1715                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
1716                    if let Some(debug_response_tx) = debug_response_tx {
1717                        debug_response_tx
1718                            .send((Err("Request skipped".to_string()), Duration::ZERO))
1719                            .ok();
1720                    }
1721                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
1722                }
1723
1724                let (prompt, _) = prompt_result?;
1725                let generation_params =
1726                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
1727                let request = open_ai::Request {
1728                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1729                    messages: vec![open_ai::RequestMessage::User {
1730                        content: open_ai::MessageContent::Plain(prompt),
1731                    }],
1732                    stream: false,
1733                    max_completion_tokens: None,
1734                    stop: generation_params.stop.unwrap_or_default(),
1735                    temperature: generation_params.temperature.unwrap_or(0.7),
1736                    tool_choice: None,
1737                    parallel_tool_calls: None,
1738                    tools: vec![],
1739                    prompt_cache_key: None,
1740                    reasoning_effort: None,
1741                };
1742
1743                log::trace!("Sending edit prediction request");
1744
1745                let before_request = Instant::now();
1746                let response = Self::send_raw_llm_request(
1747                    request,
1748                    client,
1749                    llm_token,
1750                    app_version,
1751                    #[cfg(feature = "eval-support")]
1752                    eval_cache,
1753                    #[cfg(feature = "eval-support")]
1754                    EvalCacheEntryKind::Prediction,
1755                )
1756                .await;
1757                let received_response_at = Instant::now();
1758                let request_time = received_response_at - before_request;
1759
1760                log::trace!("Got edit prediction response");
1761
1762                if let Some(debug_response_tx) = debug_response_tx {
1763                    debug_response_tx
1764                        .send((
1765                            response
1766                                .as_ref()
1767                                .map_err(|err| err.to_string())
1768                                .map(|response| response.0.clone()),
1769                            request_time,
1770                        ))
1771                        .ok();
1772                }
1773
1774                let (res, usage) = response?;
1775                let request_id = EditPredictionId(res.id.clone().into());
1776                let Some(mut output_text) = text_from_response(res) else {
1777                    return Ok((Some((request_id, None)), usage));
1778                };
1779
1780                if output_text.contains(CURSOR_MARKER) {
1781                    log::trace!("Stripping out {CURSOR_MARKER} from response");
1782                    output_text = output_text.replace(CURSOR_MARKER, "");
1783                }
1784
1785                let get_buffer_from_context = |path: &Path| {
1786                    included_files
1787                        .iter()
1788                        .find_map(|(_, buffer, probe_path, ranges)| {
1789                            if probe_path.as_ref() == path {
1790                                Some((buffer, ranges.as_slice()))
1791                            } else {
1792                                None
1793                            }
1794                        })
1795                };
1796
1797                let (edited_buffer_snapshot, edits) = match options.prompt_format {
1798                    PromptFormat::NumLinesUniDiff => {
1799                        // TODO: Implement parsing of multi-file diffs
1800                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1801                    }
1802                    PromptFormat::Minimal
1803                    | PromptFormat::MinimalQwen
1804                    | PromptFormat::SeedCoder1120 => {
1805                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
1806                            let edits = vec![];
1807                            (&active_snapshot, edits)
1808                        } else {
1809                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1810                        }
1811                    }
1812                    PromptFormat::OldTextNewText => {
1813                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
1814                            .await?
1815                    }
1816                    _ => {
1817                        bail!("unsupported prompt format {}", options.prompt_format)
1818                    }
1819                };
1820
1821                let edited_buffer = included_files
1822                    .iter()
1823                    .find_map(|(buffer, snapshot, _, _)| {
1824                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1825                            Some(buffer.clone())
1826                        } else {
1827                            None
1828                        }
1829                    })
1830                    .context("Failed to find buffer in included_buffers")?;
1831
1832                anyhow::Ok((
1833                    Some((
1834                        request_id,
1835                        Some((
1836                            inputs,
1837                            edited_buffer,
1838                            edited_buffer_snapshot.clone(),
1839                            edits,
1840                            received_response_at,
1841                        )),
1842                    )),
1843                    usage,
1844                ))
1845            }
1846        });
1847
1848        cx.spawn({
1849            async move |this, cx| {
1850                let Some((id, prediction)) =
1851                    Self::handle_api_response(&this, request_task.await, cx)?
1852                else {
1853                    return Ok(None);
1854                };
1855
1856                let Some((
1857                    inputs,
1858                    edited_buffer,
1859                    edited_buffer_snapshot,
1860                    edits,
1861                    received_response_at,
1862                )) = prediction
1863                else {
1864                    return Ok(Some(EditPredictionResult {
1865                        id,
1866                        prediction: Err(EditPredictionRejectReason::Empty),
1867                    }));
1868                };
1869
1870                // TODO telemetry: duration, etc
1871                Ok(Some(
1872                    EditPredictionResult::new(
1873                        id,
1874                        &edited_buffer,
1875                        &edited_buffer_snapshot,
1876                        edits.into(),
1877                        buffer_snapshotted_at,
1878                        received_response_at,
1879                        inputs,
1880                        cx,
1881                    )
1882                    .await,
1883                ))
1884            }
1885        })
1886    }
1887
1888    async fn send_raw_llm_request(
1889        request: open_ai::Request,
1890        client: Arc<Client>,
1891        llm_token: LlmApiToken,
1892        app_version: Version,
1893        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1894        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
1895    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1896        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1897            http_client::Url::parse(&predict_edits_url)?
1898        } else {
1899            client
1900                .http_client()
1901                .build_zed_llm_url("/predict_edits/raw", &[])?
1902        };
1903
1904        #[cfg(feature = "eval-support")]
1905        let cache_key = if let Some(cache) = eval_cache {
1906            use collections::FxHasher;
1907            use std::hash::{Hash, Hasher};
1908
1909            let mut hasher = FxHasher::default();
1910            url.hash(&mut hasher);
1911            let request_str = serde_json::to_string_pretty(&request)?;
1912            request_str.hash(&mut hasher);
1913            let hash = hasher.finish();
1914
1915            let key = (eval_cache_kind, hash);
1916            if let Some(response_str) = cache.read(key) {
1917                return Ok((serde_json::from_str(&response_str)?, None));
1918            }
1919
1920            Some((cache, request_str, key))
1921        } else {
1922            None
1923        };
1924
1925        let (response, usage) = Self::send_api_request(
1926            |builder| {
1927                let req = builder
1928                    .uri(url.as_ref())
1929                    .body(serde_json::to_string(&request)?.into());
1930                Ok(req?)
1931            },
1932            client,
1933            llm_token,
1934            app_version,
1935        )
1936        .await?;
1937
1938        #[cfg(feature = "eval-support")]
1939        if let Some((cache, request, key)) = cache_key {
1940            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1941        }
1942
1943        Ok((response, usage))
1944    }
1945
1946    fn handle_api_response<T>(
1947        this: &WeakEntity<Self>,
1948        response: Result<(T, Option<EditPredictionUsage>)>,
1949        cx: &mut gpui::AsyncApp,
1950    ) -> Result<T> {
1951        match response {
1952            Ok((data, usage)) => {
1953                if let Some(usage) = usage {
1954                    this.update(cx, |this, cx| {
1955                        this.user_store.update(cx, |user_store, cx| {
1956                            user_store.update_edit_prediction_usage(usage, cx);
1957                        });
1958                    })
1959                    .ok();
1960                }
1961                Ok(data)
1962            }
1963            Err(err) => {
1964                if err.is::<ZedUpdateRequiredError>() {
1965                    cx.update(|cx| {
1966                        this.update(cx, |this, _cx| {
1967                            this.update_required = true;
1968                        })
1969                        .ok();
1970
1971                        let error_message: SharedString = err.to_string().into();
1972                        show_app_notification(
1973                            NotificationId::unique::<ZedUpdateRequiredError>(),
1974                            cx,
1975                            move |cx| {
1976                                cx.new(|cx| {
1977                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1978                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1979                                })
1980                            },
1981                        );
1982                    })
1983                    .ok();
1984                }
1985                Err(err)
1986            }
1987        }
1988    }
1989
1990    async fn send_api_request<Res>(
1991        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1992        client: Arc<Client>,
1993        llm_token: LlmApiToken,
1994        app_version: Version,
1995    ) -> Result<(Res, Option<EditPredictionUsage>)>
1996    where
1997        Res: DeserializeOwned,
1998    {
1999        let http_client = client.http_client();
2000        let mut token = llm_token.acquire(&client).await?;
2001        let mut did_retry = false;
2002
2003        loop {
2004            let request_builder = http_client::Request::builder().method(Method::POST);
2005
2006            let request = build(
2007                request_builder
2008                    .header("Content-Type", "application/json")
2009                    .header("Authorization", format!("Bearer {}", token))
2010                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
2011            )?;
2012
2013            let mut response = http_client.send(request).await?;
2014
2015            if let Some(minimum_required_version) = response
2016                .headers()
2017                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2018                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2019            {
2020                anyhow::ensure!(
2021                    app_version >= minimum_required_version,
2022                    ZedUpdateRequiredError {
2023                        minimum_version: minimum_required_version
2024                    }
2025                );
2026            }
2027
2028            if response.status().is_success() {
2029                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2030
2031                let mut body = Vec::new();
2032                response.body_mut().read_to_end(&mut body).await?;
2033                return Ok((serde_json::from_slice(&body)?, usage));
2034            } else if !did_retry
2035                && response
2036                    .headers()
2037                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
2038                    .is_some()
2039            {
2040                did_retry = true;
2041                token = llm_token.refresh(&client).await?;
2042            } else {
2043                let mut body = String::new();
2044                response.body_mut().read_to_string(&mut body).await?;
2045                anyhow::bail!(
2046                    "Request failed with status: {:?}\nBody: {}",
2047                    response.status(),
2048                    body
2049                );
2050            }
2051        }
2052    }
2053
    /// Gap since the last edit after which the next edit counts as coming out
    /// of idle (triggering an immediate context refresh).
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied before refetching context while the user is actively
    /// editing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
2056
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and again after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context retrieval only applies to the Zeta2 model in agentic
        // context mode.
        if !matches!(self.edit_prediction_model, ZetaEditPredictionModel::Zeta2) {
            return;
        }

        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        // No recorded timestamp means this is the first edit we've seen,
        // which counts as coming out of idle.
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Replace any previously scheduled debounce task with a fresh one.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    // Coming out of idle: refresh immediately.
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    // Mid-editing: wait for a pause before refreshing.
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Store the refresh task on the project so it runs to
                    // completion (see `refresh_context`).
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
2108
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    //
    // The flow is: build a context-retrieval planning prompt from the excerpt
    // around the cursor -> ask the model to emit search-tool calls -> run the
    // resulting searches -> store the found excerpts on the project state.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Nothing to do for untracked projects.
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        // Context retrieval is only meaningful in agentic context mode.
        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            // No usable excerpt around the cursor; skip the refresh.
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the planning prompt synchronously so prompt construction
        // errors surface immediately rather than inside the spawned task.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project.events(cx),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // JSON schema (and description) for the search tool, derived once
        // from `SearchToolInput` and reused across refreshes.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        // Planning request: the model is expected to respond with calls to
        // the retrieval search tool rather than plain text.
        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect search queries from the tool calls, skipping (but
            // logging) calls to tools we didn't offer.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                // Clear the in-flight marker set by `refresh_context_if_needed`.
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
2320
2321    pub fn set_context(
2322        &mut self,
2323        project: Entity<Project>,
2324        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2325    ) {
2326        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2327            zeta_project.context = Some(context);
2328        }
2329    }
2330
2331    fn gather_nearby_diagnostics(
2332        cursor_offset: usize,
2333        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
2334        snapshot: &BufferSnapshot,
2335        max_diagnostics_bytes: usize,
2336    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
2337        // TODO: Could make this more efficient
2338        let mut diagnostic_groups = Vec::new();
2339        for (language_server_id, diagnostics) in diagnostic_sets {
2340            let mut groups = Vec::new();
2341            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
2342            diagnostic_groups.extend(
2343                groups
2344                    .into_iter()
2345                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
2346            );
2347        }
2348
2349        // sort by proximity to cursor
2350        diagnostic_groups.sort_by_key(|group| {
2351            let range = &group.entries[group.primary_ix].range;
2352            if range.start >= cursor_offset {
2353                range.start - cursor_offset
2354            } else if cursor_offset >= range.end {
2355                cursor_offset - range.end
2356            } else {
2357                (cursor_offset - range.start).min(range.end - cursor_offset)
2358            }
2359        });
2360
2361        let mut results = Vec::new();
2362        let mut diagnostic_groups_truncated = false;
2363        let mut diagnostics_byte_count = 0;
2364        for group in diagnostic_groups {
2365            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
2366            diagnostics_byte_count += raw_value.get().len();
2367            if diagnostics_byte_count > max_diagnostics_bytes {
2368                diagnostic_groups_truncated = true;
2369                break;
2370            }
2371            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
2372        }
2373
2374        (results, diagnostic_groups_truncated)
2375    }
2376
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds the cloud `PredictEditsRequest` used by the zeta CLI, gathering
    /// syntax-based context around `position` in `buffer`.
    ///
    /// Returns an error when the buffer has no file path. Panics for the
    /// agentic context mode, which the CLI does not support yet.
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        // Clone of the syntax index's state handle, if the project has one.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, when it has one.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // All entity reads are done above; the request itself is assembled
        // off the main thread.
        cx.background_spawn(async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                    PredictEditsRequestTrigger::Other,
                )
            })
        })
    }
2455
2456    pub fn wait_for_initial_indexing(
2457        &mut self,
2458        project: &Entity<Project>,
2459        cx: &mut Context<Self>,
2460    ) -> Task<Result<()>> {
2461        let zeta_project = self.get_or_init_zeta_project(project, cx);
2462        if let Some(syntax_index) = &zeta_project.syntax_index {
2463            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2464        } else {
2465            Task::ready(Ok(()))
2466        }
2467    }
2468
2469    fn is_file_open_source(
2470        &self,
2471        project: &Entity<Project>,
2472        file: &Arc<dyn File>,
2473        cx: &App,
2474    ) -> bool {
2475        if !file.is_local() || file.is_private() {
2476            return false;
2477        }
2478        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2479            return false;
2480        };
2481        zeta_project
2482            .license_detection_watchers
2483            .get(&file.worktree_id(cx))
2484            .as_ref()
2485            .is_some_and(|watcher| watcher.is_project_open_source())
2486    }
2487
2488    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2489        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2490    }
2491
2492    fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2493        if !self.data_collection_choice.is_enabled() {
2494            return false;
2495        }
2496        events.iter().all(|event| {
2497            matches!(
2498                event.as_ref(),
2499                Event::BufferChange {
2500                    in_open_source_repo: true,
2501                    ..
2502                }
2503            )
2504        })
2505    }
2506
2507    fn load_data_collection_choice() -> DataCollectionChoice {
2508        let choice = KEY_VALUE_STORE
2509            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2510            .log_err()
2511            .flatten();
2512
2513        match choice.as_deref() {
2514            Some("true") => DataCollectionChoice::Enabled,
2515            Some("false") => DataCollectionChoice::Disabled,
2516            Some(_) => {
2517                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2518                DataCollectionChoice::NotAnswered
2519            }
2520            None => DataCollectionChoice::NotAnswered,
2521        }
2522    }
2523
    /// Iterates over the predictions that have been shown to the user, in the
    /// order they were shown.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2527
    /// Number of predictions that have been shown to the user.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2531
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2535
    /// Records a user rating for `prediction` and reports it via telemetry.
    ///
    /// Marks the prediction as rated (so the UI can reflect that it has been
    /// rated) and flushes telemetry immediately rather than waiting for the
    /// next batch.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Flush right away so the rating isn't lost if the app exits shortly
        // after the user rates a prediction.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2554}
2555
2556pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2557    let choice = res.choices.pop()?;
2558    let output_text = match choice.message {
2559        open_ai::RequestMessage::Assistant {
2560            content: Some(open_ai::MessageContent::Plain(content)),
2561            ..
2562        } => content,
2563        open_ai::RequestMessage::Assistant {
2564            content: Some(open_ai::MessageContent::Multipart(mut content)),
2565            ..
2566        } => {
2567            if content.is_empty() {
2568                log::error!("No output from Baseten completion response");
2569                return None;
2570            }
2571
2572            match content.remove(0) {
2573                open_ai::MessagePart::Text { text } => text,
2574                open_ai::MessagePart::Image { .. } => {
2575                    log::error!("Expected text, got an image");
2576                    return None;
2577                }
2578            }
2579        }
2580        _ => {
2581            log::error!("Invalid response message: {:?}", choice.message);
2582            return None;
2583        }
2584    };
2585    Some(output_text)
2586}
2587
/// Error raised when the server requires a newer Zed version to keep serving
/// edit predictions; `minimum_version` is the lowest acceptable version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2595
2596fn make_syntax_context_cloud_request(
2597    excerpt_path: Arc<Path>,
2598    context: EditPredictionContext,
2599    events: Vec<Arc<predict_edits_v3::Event>>,
2600    can_collect_data: bool,
2601    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2602    diagnostic_groups_truncated: bool,
2603    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2604    debug_info: bool,
2605    worktrees: &Vec<worktree::Snapshot>,
2606    index_state: Option<&SyntaxIndexState>,
2607    prompt_max_bytes: Option<usize>,
2608    prompt_format: PromptFormat,
2609    trigger: PredictEditsRequestTrigger,
2610) -> predict_edits_v3::PredictEditsRequest {
2611    let mut signatures = Vec::new();
2612    let mut declaration_to_signature_index = HashMap::default();
2613    let mut referenced_declarations = Vec::new();
2614
2615    for snippet in context.declarations {
2616        let project_entry_id = snippet.declaration.project_entry_id();
2617        let Some(path) = worktrees.iter().find_map(|worktree| {
2618            worktree.entry_for_id(project_entry_id).map(|entry| {
2619                let mut full_path = RelPathBuf::new();
2620                full_path.push(worktree.root_name());
2621                full_path.push(&entry.path);
2622                full_path
2623            })
2624        }) else {
2625            continue;
2626        };
2627
2628        let parent_index = index_state.and_then(|index_state| {
2629            snippet.declaration.parent().and_then(|parent| {
2630                add_signature(
2631                    parent,
2632                    &mut declaration_to_signature_index,
2633                    &mut signatures,
2634                    index_state,
2635                )
2636            })
2637        });
2638
2639        let (text, text_is_truncated) = snippet.declaration.item_text();
2640        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2641            path: path.as_std_path().into(),
2642            text: text.into(),
2643            range: snippet.declaration.item_line_range(),
2644            text_is_truncated,
2645            signature_range: snippet.declaration.signature_range_in_item_text(),
2646            parent_index,
2647            signature_score: snippet.score(DeclarationStyle::Signature),
2648            declaration_score: snippet.score(DeclarationStyle::Declaration),
2649            score_components: snippet.components,
2650        });
2651    }
2652
2653    let excerpt_parent = index_state.and_then(|index_state| {
2654        context
2655            .excerpt
2656            .parent_declarations
2657            .last()
2658            .and_then(|(parent, _)| {
2659                add_signature(
2660                    *parent,
2661                    &mut declaration_to_signature_index,
2662                    &mut signatures,
2663                    index_state,
2664                )
2665            })
2666    });
2667
2668    predict_edits_v3::PredictEditsRequest {
2669        excerpt_path,
2670        excerpt: context.excerpt_text.body,
2671        excerpt_line_range: context.excerpt.line_range,
2672        excerpt_range: context.excerpt.range,
2673        cursor_point: predict_edits_v3::Point {
2674            line: predict_edits_v3::Line(context.cursor_point.row),
2675            column: context.cursor_point.column,
2676        },
2677        referenced_declarations,
2678        included_files: vec![],
2679        signatures,
2680        excerpt_parent,
2681        events,
2682        can_collect_data,
2683        diagnostic_groups,
2684        diagnostic_groups_truncated,
2685        git_info,
2686        debug_info,
2687        prompt_max_bytes,
2688        prompt_format,
2689        trigger,
2690    }
2691}
2692
2693fn add_signature(
2694    declaration_id: DeclarationId,
2695    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2696    signatures: &mut Vec<Signature>,
2697    index: &SyntaxIndexState,
2698) -> Option<usize> {
2699    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2700        return Some(*signature_index);
2701    }
2702    let Some(parent_declaration) = index.declaration(declaration_id) else {
2703        log::error!("bug: missing parent declaration");
2704        return None;
2705    };
2706    let parent_index = parent_declaration.parent().and_then(|parent| {
2707        add_signature(parent, declaration_to_signature_index, signatures, index)
2708    });
2709    let (text, text_is_truncated) = parent_declaration.signature_text();
2710    let signature_index = signatures.len();
2711    signatures.push(Signature {
2712        text: text.into(),
2713        text_is_truncated,
2714        parent_index,
2715        range: parent_declaration.signature_line_range(),
2716    });
2717    declaration_to_signature_index.insert(declaration_id, signature_index);
2718    Some(signature_index)
2719}
2720
/// Cache key for eval-support lookups: the kind of cached entry paired with a
/// hash identifying the request.
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2723
/// The stage of a prediction request that an eval cache entry corresponds to.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2731
/// Lowercase names for cache entry kinds, as used in cache key display.
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(name)
    }
}
2742
/// Storage backend for caching eval results across runs.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached value for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` for `key`, alongside the `input` it was produced from.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2748
/// The user's choice about whether edit-prediction data may be collected.
///
/// `PartialEq`/`Eq` are derived so choices can be compared directly instead
/// of only through `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    /// The user has not been asked yet (treated as disabled by `is_enabled`).
    NotAnswered,
    /// The user opted in to data collection.
    Enabled,
    /// The user opted out of data collection.
    Disabled,
}
2755
2756impl DataCollectionChoice {
2757    pub fn is_enabled(self) -> bool {
2758        match self {
2759            Self::Enabled => true,
2760            Self::NotAnswered | Self::Disabled => false,
2761        }
2762    }
2763
2764    pub fn is_answered(self) -> bool {
2765        match self {
2766            Self::Enabled | Self::Disabled => true,
2767            Self::NotAnswered => false,
2768        }
2769    }
2770
2771    #[must_use]
2772    pub fn toggle(&self) -> DataCollectionChoice {
2773        match self {
2774            Self::Enabled => Self::Disabled,
2775            Self::Disabled => Self::Enabled,
2776            Self::NotAnswered => Self::Enabled,
2777        }
2778    }
2779}
2780
2781impl From<bool> for DataCollectionChoice {
2782    fn from(value: bool) -> Self {
2783        match value {
2784            true => DataCollectionChoice::Enabled,
2785            false => DataCollectionChoice::Disabled,
2786        }
2787    }
2788}
2789
/// Marker type tracking whether the edit-prediction upsell has been dismissed.
struct ZedPredictUpsell;
2791
2792impl Dismissable for ZedPredictUpsell {
2793    const KEY: &'static str = "dismissed-edit-predict-upsell";
2794
2795    fn dismissed() -> bool {
2796        // To make this backwards compatible with older versions of Zed, we
2797        // check if the user has seen the previous Edit Prediction Onboarding
2798        // before, by checking the data collection choice which was written to
2799        // the database once the user clicked on "Accept and Enable"
2800        if KEY_VALUE_STORE
2801            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2802            .log_err()
2803            .is_some_and(|s| s.is_some())
2804        {
2805            return true;
2806        }
2807
2808        KEY_VALUE_STORE
2809            .read_kvp(Self::KEY)
2810            .log_err()
2811            .is_some_and(|s| s.is_some())
2812    }
2813}
2814
/// Whether the edit-prediction upsell modal should be shown, i.e. the user
/// has not dismissed it (directly or via the legacy onboarding).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2818
/// Registers Zeta's workspace actions and sets up command-palette gating.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Rating UI is only reachable behind the rate-completions flag.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding clears the configured edit-prediction provider
        // in the settings file.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2854
/// Hides Zeta's actions in the command palette unless the relevant feature
/// flags allow them, and hides all of them while AI is disabled.
fn feature_gate_predict_edits_actions(cx: &mut App) {
    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
    let zeta_all_action_types = [
        TypeId::of::<RateCompletions>(),
        TypeId::of::<ResetOnboarding>(),
        zed_actions::OpenZedPredictOnboarding.type_id(),
        TypeId::of::<ClearHistory>(),
        TypeId::of::<ThumbsUpActivePrediction>(),
        TypeId::of::<ThumbsDownActivePrediction>(),
        TypeId::of::<NextEdit>(),
        TypeId::of::<PreviousEdit>(),
    ];

    // Start hidden; the observers below re-show actions as flags and
    // settings permit.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_action_types(&rate_completion_action_types);
        filter.hide_action_types(&reset_onboarding_action_types);
        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
    });

    cx.observe_global::<SettingsStore>(move |cx| {
        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            if is_ai_disabled {
                // Disabling AI hides every Zeta action, not just rating.
                filter.hide_action_types(&zeta_all_action_types);
            } else if has_feature_flag {
                filter.show_action_types(&rate_completion_action_types);
            } else {
                filter.hide_action_types(&rate_completion_action_types);
            }
        });
    })
    .detach();

    // Keep rating actions in sync with live feature-flag changes.
    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
        if !DisableAiSettings::get_global(cx).disable_ai {
            if is_enabled {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.show_action_types(&rate_completion_action_types);
                });
            } else {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.hide_action_types(&rate_completion_action_types);
                });
            }
        }
    })
    .detach();
}
2906
2907#[cfg(test)]
2908mod tests {
2909    use std::{path::Path, sync::Arc};
2910
2911    use client::UserStore;
2912    use clock::FakeSystemClock;
2913    use cloud_llm_client::{
2914        EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
2915    };
2916    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2917    use futures::{
2918        AsyncReadExt, StreamExt,
2919        channel::{mpsc, oneshot},
2920    };
2921    use gpui::{
2922        Entity, TestAppContext,
2923        http_client::{FakeHttpClient, Response},
2924        prelude::*,
2925    };
2926    use indoc::indoc;
2927    use language::OffsetRangeExt as _;
2928    use open_ai::Usage;
2929    use pretty_assertions::{assert_eq, assert_matches};
2930    use project::{FakeFs, Project};
2931    use serde_json::json;
2932    use settings::SettingsStore;
2933    use util::path;
2934    use uuid::Uuid;
2935
2936    use crate::{BufferEditPrediction, Zeta};
2937
    /// Exercises the full prediction state machine: a prediction in the
    /// current buffer, a context refresh driven by a search tool call, a
    /// rejection, and a prediction that jumps to another file.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        // An edit within the same buffer surfaces as a `Local` prediction.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        // Respond with a tool call asking to search the other file.
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // From buffer1's perspective this is a `Jump` prediction to 2.txt.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // From the target buffer's own perspective, it's a `Local` prediction.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
3084
    /// A single prediction round-trip: send a request, answer with a canned
    /// diff, and verify the resulting edit's position and inserted text.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        // The diff should resolve to a single insertion at the cursor.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3148
    /// Buffer edits made before a prediction request are reported to the
    /// model as unified-diff events in the prompt.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Edit the registered buffer so the change is recorded as an event.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (request, respond_tx) = requests.predict.next().await.unwrap();

        // The prompt should contain the edit above as a unified diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3222
    // Verifies that a model response whose diff makes no actual change is
    // discarded, and is reported to the server as rejected with reason
    // `Empty` (not shown to the user).
    #[gpui::test]
    async fn test_empty_prediction(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // A diff that replaces "How" with the identical "How", i.e. no-op.
        const NO_OP_DIFF: &str = indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How
             Bye
        "};

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let response = model_response(NO_OP_DIFF);
        // Remember the request id so we can match it in the rejection report.
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        // The no-op prediction must not become the current prediction.
        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::Empty,
                was_shown: false
            }]
        );
    }
3286
    // Verifies that when the user manually types the predicted change before
    // the response arrives, interpolating the prediction onto the new buffer
    // contents yields no edits, so the prediction is dropped and reported as
    // rejected with reason `InterpolatedEmpty`.
    #[gpui::test]
    async fn test_interpolated_empty(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // Before the response arrives, the user applies the same change the
        // model is about to predict.
        buffer.update(cx, |buffer, cx| {
            buffer.set_text("Hello!\nHow are you?\nBye", cx);
        });

        let response = model_response(SIMPLE_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        // Interpolation produced no edits, so there is no current prediction.
        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::InterpolatedEmpty,
                was_shown: false
            }]
        );
    }
3345
    // Standard non-empty model response used by several tests below: a diff
    // against root/foo.md that extends "How" to "How are you?".
    const SIMPLE_DIFF: &str = indoc! { r"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ ... @@
         Hello!
        -How
        +How are you?
         Bye
    "};
3355
    // Verifies that a newer completed prediction replaces the current one,
    // and that the replaced prediction is reported as rejected with reason
    // `Replaced` (it was never shown).
    #[gpui::test]
    async fn test_replace_current(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_tx.send(first_response).unwrap();

        cx.run_until_parked();

        // The first prediction becomes current.
        zeta.read_with(cx, |zeta, cx| {
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // a second request is triggered
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_tx.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // second replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as replaced
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }]
        );
    }
3435
    // Verifies that when a newer prediction is judged worse than the current
    // one, the current prediction is kept and the newer one is reported as
    // rejected with reason `CurrentPreferred`.
    #[gpui::test]
    async fn test_current_preferred(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_tx.send(first_response).unwrap();

        cx.run_until_parked();

        // The first prediction becomes current.
        zeta.read_with(cx, |zeta, cx| {
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // a second request is triggered
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        // worse than current prediction
        let second_response = model_response(indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "});
        let second_id = second_response.id.clone();
        respond_tx.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // first is preferred over second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // second is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: second_id,
                reason: EditPredictionRejectReason::CurrentPreferred,
                was_shown: false
            }]
        );
    }
3524
    // Verifies that when two requests are in flight and the later one
    // completes first, the earlier request is treated as cancelled: its
    // eventual response does not replace the current prediction, and it is
    // reported as rejected with reason `Canceled`.
    #[gpui::test]
    async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            // start two refresh tasks
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_first) = requests.predict.next().await.unwrap();
        let (_, respond_second) = requests.predict.next().await.unwrap();

        // wait for throttle
        cx.run_until_parked();

        // second responds first
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_second.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // Now let the first (older) request respond, after it was cancelled.
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still second, since first was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            }]
        );
    }
3610
    // Verifies the cancellation policy when a third request starts while two
    // are already pending: the middle (second) request is marked cancelled,
    // the first may still complete and become current, and the third then
    // replaces the first. The rejection report should contain the second as
    // `Canceled` and the first as `Replaced`.
    #[gpui::test]
    async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            // start two refresh tasks
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // wait for throttle, so requests are sent
        cx.run_until_parked();

        let (_, respond_first) = requests.predict.next().await.unwrap();
        let (_, respond_second) = requests.predict.next().await.unwrap();

        zeta.update(cx, |zeta, cx| {
            // start a third request
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

            // 2 are pending, so 2nd is cancelled
            assert_eq!(
                zeta.get_or_init_zeta_project(&project, cx)
                    .cancelled_predictions
                    .iter()
                    .copied()
                    .collect::<Vec<_>>(),
                [1]
            );
        });

        // wait for throttle
        cx.run_until_parked();

        let (_, respond_third) = requests.predict.next().await.unwrap();

        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // The cancelled (second) request now responds; its result is ignored.
        let cancelled_response = model_response(SIMPLE_DIFF);
        let cancelled_id = cancelled_response.id.clone();
        respond_second.send(cancelled_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still first, since second was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        let third_response = model_response(SIMPLE_DIFF);
        let third_response_id = third_response.id.clone();
        respond_third.send(third_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // third completes and replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                third_response_id
            );
        });

        // second is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[
                EditPredictionRejection {
                    request_id: cancelled_id,
                    reason: EditPredictionRejectReason::Canceled,
                    was_shown: false
                },
                EditPredictionRejection {
                    request_id: first_id,
                    reason: EditPredictionRejectReason::Replaced,
                    was_shown: false
                }
            ]
        );
    }
3738
3739    // Skipped until we start including diagnostics in prompt
3740    // #[gpui::test]
3741    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3742    //     let (zeta, mut req_rx) = init_test(cx);
3743    //     let fs = FakeFs::new(cx.executor());
3744    //     fs.insert_tree(
3745    //         "/root",
3746    //         json!({
3747    //             "foo.md": "Hello!\nBye"
3748    //         }),
3749    //     )
3750    //     .await;
3751    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3752
3753    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3754    //     let diagnostic = lsp::Diagnostic {
3755    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3756    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
3757    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3758    //         ..Default::default()
3759    //     };
3760
3761    //     project.update(cx, |project, cx| {
3762    //         project.lsp_store().update(cx, |lsp_store, cx| {
3763    //             // Create some diagnostics
3764    //             lsp_store
3765    //                 .update_diagnostics(
3766    //                     LanguageServerId(0),
3767    //                     lsp::PublishDiagnosticsParams {
3768    //                         uri: path_to_buffer_uri.clone(),
3769    //                         diagnostics: vec![diagnostic],
3770    //                         version: None,
3771    //                     },
3772    //                     None,
3773    //                     language::DiagnosticSourceKind::Pushed,
3774    //                     &[],
3775    //                     cx,
3776    //                 )
3777    //                 .unwrap();
3778    //         });
3779    //     });
3780
3781    //     let buffer = project
3782    //         .update(cx, |project, cx| {
3783    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3784    //             project.open_buffer(path, cx)
3785    //         })
3786    //         .await
3787    //         .unwrap();
3788
3789    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3790    //     let position = snapshot.anchor_before(language::Point::new(0, 0));
3791
3792    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
3793    //         zeta.request_prediction(&project, &buffer, position, cx)
3794    //     });
3795
3796    //     let (request, _respond_tx) = req_rx.next().await.unwrap();
3797
3798    //     assert_eq!(request.diagnostic_groups.len(), 1);
3799    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3800    //         .unwrap();
3801    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3802    //     assert_eq!(
3803    //         value,
3804    //         json!({
3805    //             "entries": [{
3806    //                 "range": {
3807    //                     "start": 8,
3808    //                     "end": 10
3809    //                 },
3810    //                 "diagnostic": {
3811    //                     "source": null,
3812    //                     "code": null,
3813    //                     "code_description": null,
3814    //                     "severity": 1,
3815    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3816    //                     "markdown": null,
3817    //                     "group_id": 0,
3818    //                     "is_primary": true,
3819    //                     "is_disk_based": false,
3820    //                     "is_unnecessary": false,
3821    //                     "source_kind": "Pushed",
3822    //                     "data": null,
3823    //                     "underline": true
3824    //                 }
3825    //             }],
3826    //             "primary_ix": 0
3827    //         })
3828    //     );
3829    // }
3830
3831    fn model_response(text: &str) -> open_ai::Response {
3832        open_ai::Response {
3833            id: Uuid::new_v4().to_string(),
3834            object: "response".into(),
3835            created: 0,
3836            model: "model".into(),
3837            choices: vec![open_ai::Choice {
3838                index: 0,
3839                message: open_ai::RequestMessage::Assistant {
3840                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
3841                    tool_calls: vec![],
3842                },
3843                finish_reason: None,
3844            }],
3845            usage: Usage {
3846                prompt_tokens: 0,
3847                completion_tokens: 0,
3848                total_tokens: 0,
3849            },
3850        }
3851    }
3852
3853    fn prompt_from_request(request: &open_ai::Request) -> &str {
3854        assert_eq!(request.messages.len(), 1);
3855        let open_ai::RequestMessage::User {
3856            content: open_ai::MessageContent::Plain(content),
3857            ..
3858        } = &request.messages[0]
3859        else {
3860            panic!(
3861                "Request does not have single user message of type Plain. {:#?}",
3862                request
3863            );
3864        };
3865        content
3866    }
3867
    /// Receiving ends of the fake-server channels created by `init_test`.
    /// Each item pairs a decoded request body with a oneshot sender that the
    /// test uses to deliver the corresponding response.
    struct RequestChannels {
        // Requests to `/predict_edits/raw`, answered with a model response.
        predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
        // Requests to `/predict_edits/reject`, answered with `()`.
        reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
    }
3872
    /// Shared test setup: installs test settings, wires up a fake HTTP client
    /// that routes the endpoints Zeta talks to into in-memory channels, and
    /// returns the global `Zeta` entity plus the channel receivers so tests
    /// can intercept requests and script responses.
    fn init_test(cx: &mut TestAppContext) -> (Entity<Zeta>, RequestChannels) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
            let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let predict_req_tx = predict_req_tx.clone();
                    let reject_req_tx = reject_req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh: always hand out a static token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction requests: forward the decoded body to
                            // the test and block until it sends a response.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            // Rejection reports: same forwarding scheme.
                            "/predict_edits/reject" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            // Sign in with fake credentials so requests are authenticated.
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (
                zeta,
                RequestChannels {
                    predict: predict_req_rx,
                    reject: reject_req_rx,
                },
            )
        })
    }
3939}