zeta.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
   7    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
   8    MINIMUM_REQUIRED_VERSION_HEADER_NAME, RejectEditPredictionsBody, ZED_VERSION_HEADER_NAME,
   9};
  10use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  11use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  12use collections::{HashMap, HashSet};
  13use command_palette_hooks::CommandPaletteFilter;
  14use db::kvp::{Dismissable, KEY_VALUE_STORE};
  15use edit_prediction_context::{
  16    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  17    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  18    SyntaxIndex, SyntaxIndexState,
  19};
  20use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
  21use futures::channel::{mpsc, oneshot};
  22use futures::{AsyncReadExt as _, StreamExt as _};
  23use gpui::{
  24    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  25    http_client::{self, AsyncBody, Method},
  26    prelude::*,
  27};
  28use language::{
  29    Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
  30};
  31use language::{BufferSnapshot, OffsetRangeExt};
  32use language_model::{LlmApiToken, RefreshLlmTokenListener};
  33use open_ai::FunctionDefinition;
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
  39use std::any::{Any as _, TypeId};
  40use std::collections::{VecDeque, hash_map};
  41use telemetry_events::EditPredictionRating;
  42use workspace::Workspace;
  43
  44use std::ops::Range;
  45use std::path::Path;
  46use std::rc::Rc;
  47use std::str::FromStr as _;
  48use std::sync::{Arc, LazyLock};
  49use std::time::{Duration, Instant};
  50use std::{env, mem};
  51use thiserror::Error;
  52use util::rel_path::RelPathBuf;
  53use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
  54use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  55
  56pub mod assemble_excerpts;
  57mod license_detection;
  58mod onboarding_modal;
  59mod prediction;
  60mod provider;
  61mod rate_prediction_modal;
  62pub mod retrieval_search;
  63mod sweep_ai;
  64pub mod udiff;
  65mod xml_edits;
  66pub mod zeta1;
  67
  68#[cfg(test)]
  69mod zeta_tests;
  70
  71use crate::assemble_excerpts::assemble_excerpts;
  72use crate::license_detection::LicenseDetectionWatcher;
  73use crate::onboarding_modal::ZedPredictModal;
  74pub use crate::prediction::EditPrediction;
  75pub use crate::prediction::EditPredictionId;
  76pub use crate::prediction::EditPredictionInputs;
  77use crate::prediction::EditPredictionResult;
  78use crate::rate_prediction_modal::{
  79    NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
  80    ThumbsUpActivePrediction,
  81};
  82use crate::sweep_ai::SweepAi;
  83use crate::zeta1::request_prediction_with_zeta1;
  84pub use provider::ZetaEditPredictionProvider;
  85
// Workspace actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  97
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Consecutive edits whose end points are within this many rows of each other
/// are coalesced into a single buffer-change event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key-value store key persisting the user's data-collection choice.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 102
 103pub struct SweepFeatureFlag;
 104
 105impl FeatureFlag for SweepFeatureFlag {
 106    const NAME: &str = "sweep-ai";
 107}
/// Default sizing for the excerpt sent around the cursor: between 128 and
/// 512 bytes, aiming to place half of the excerpt before the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};
 113
/// Default context mode: agentic retrieval with the default excerpt sizing.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for the syntax-index based context mode.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        // No declaration retrieval by default.
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };
 130
/// Default top-level options for the Zeta prediction pipeline.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    // Edits closer together than this are grouped into one event.
    buffer_change_grouping_interval: Duration::from_secs(1),
};
 139
 140static USE_OLLAMA: LazyLock<bool> =
 141    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 142static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 143    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
 144        "qwen3-coder:30b".to_string()
 145    } else {
 146        "yqvev8r3".to_string()
 147    })
 148});
 149static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 150    match env::var("ZED_ZETA2_MODEL").as_deref() {
 151        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 152        Ok(model) => model,
 153        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 154        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 155    }
 156    .to_string()
 157});
 158static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 159    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 160        if *USE_OLLAMA {
 161            Some("http://localhost:11434/v1/chat/completions".into())
 162        } else {
 163            None
 164        }
 165    })
 166});
 167
/// Feature flag gating the Zeta2 prediction pipeline; always on for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}
 177
/// Newtype wrapper registering the singleton [`Zeta`] entity as a GPUI global.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 182
/// Coordinates edit predictions: owns per-project state, LLM auth, and
/// bookkeeping for shown/rated/rejected predictions.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Refreshes `llm_token` whenever the listener fires; see `Self::new`.
    _llm_token_subscription: Subscription,
    // Per-project state keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the backend requires a newer client
    // (cf. MINIMUM_REQUIRED_VERSION_HEADER_NAME) — setter not visible here.
    update_required: bool,
    // When present, pipeline debug events are streamed here (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    // Rejections buffered locally until flushed to the server.
    rejected_predictions: Vec<EditPredictionRejection>,
    // Signals the background task in `Self::new` to flush rejections.
    reject_predictions_tx: mpsc::UnboundedSender<()>,
    reject_predictions_debounce_task: Option<Task<()>>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
 203
/// Which backend serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}
 211
/// Tunable parameters of the prediction pipeline (see `DEFAULT_OPTIONS`).
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}

/// How context for a prediction is gathered: agentic (model-driven search)
/// or via the local syntax index.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    Agentic(AgenticContextOptions),
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
 232
 233impl ContextMode {
 234    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 235        match self {
 236            ContextMode::Agentic(options) => &options.excerpt,
 237            ContextMode::Syntax(options) => &options.excerpt,
 238        }
 239    }
 240}
 241
/// Debug events emitted over the channel returned by `Zeta::debug_info`.
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

/// Payload for `ContextRetrievalStarted`.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Payload for context-retrieval progress/completion events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Payload for `EditPredictionRequested`.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    // The locally-assembled prompt, or an error message if assembly failed.
    pub local_prompt: Result<String, String>,
    // Resolves with the model response (or error string) and elapsed time.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}

/// Payload for `SearchQueriesGenerated`.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 282
/// Per-project prediction state.
struct ZetaProject {
    // Only populated in `ContextMode::Syntax` (see `get_or_init_zeta_project`).
    syntax_index: Option<Entity<SyntaxIndex>>,
    // Finalized edit events, oldest first; bounded by `EVENT_COUNT_MAX`.
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    // The still-open event; subsequent nearby edits may coalesce into it.
    last_event: Option<LastEvent>,
    // Recently active project paths, most recent at the front.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // At most two predictions may be in flight at once.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions cancelled before their tasks resolved.
    cancelled_predictions: HashSet<usize>,
    // Retrieved context ranges per buffer, when retrieval has run.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    // One license watcher per worktree, used to detect open-source repos
    // (see `LastEvent::finalize`).
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
 301
 302impl ZetaProject {
 303    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
 304        self.events
 305            .iter()
 306            .cloned()
 307            .chain(
 308                self.last_event
 309                    .as_ref()
 310                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 311            )
 312            .collect()
 313    }
 314
 315    fn cancel_pending_prediction(
 316        &mut self,
 317        pending_prediction: PendingPrediction,
 318        cx: &mut Context<Zeta>,
 319    ) {
 320        self.cancelled_predictions.insert(pending_prediction.id);
 321
 322        cx.spawn(async move |this, cx| {
 323            let Some(prediction_id) = pending_prediction.task.await else {
 324                return;
 325            };
 326
 327            this.update(cx, |this, cx| {
 328                this.reject_prediction(
 329                    prediction_id,
 330                    EditPredictionRejectReason::Canceled,
 331                    false,
 332                    cx,
 333                );
 334            })
 335            .ok();
 336        })
 337        .detach()
 338    }
 339}
 340
/// The prediction currently offered to the user for a given project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    // What triggered this prediction (a buffer edit or a diagnostics update).
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // Whether the prediction has been displayed to the user yet.
    pub was_shown: bool,
}
 347
 348impl CurrentEditPrediction {
 349    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 350        let Some(new_edits) = self
 351            .prediction
 352            .interpolate(&self.prediction.buffer.read(cx))
 353        else {
 354            return false;
 355        };
 356
 357        if self.prediction.buffer != old_prediction.prediction.buffer {
 358            return true;
 359        }
 360
 361        let Some(old_edits) = old_prediction
 362            .prediction
 363            .interpolate(&old_prediction.prediction.buffer.read(cx))
 364        else {
 365            return true;
 366        };
 367
 368        let requested_by_buffer_id = self.requested_by.buffer_id();
 369
 370        // This reduces the occurrence of UI thrash from replacing edits
 371        //
 372        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 373        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 374            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 375            && old_edits.len() == 1
 376            && new_edits.len() == 1
 377        {
 378            let (old_range, old_text) = &old_edits[0];
 379            let (new_range, new_text) = &new_edits[0];
 380            new_range == old_range && new_text.starts_with(old_text.as_ref())
 381        } else {
 382            true
 383        }
 384    }
 385}
 386
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    // An edit in the buffer with this entity id.
    Buffer(EntityId),
}
 392
 393impl PredictionRequestedBy {
 394    pub fn buffer_id(&self) -> Option<EntityId> {
 395        match self {
 396            PredictionRequestedBy::DiagnosticsUpdate => None,
 397            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 398        }
 399    }
 400}
 401
/// An in-flight prediction request; the task resolves to the resulting
/// prediction's id, if any.
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    // The prediction edits this buffer directly.
    Local { prediction: &'a EditPrediction },
    // The prediction targets another buffer; offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}
 414
 415#[cfg(test)]
 416impl std::ops::Deref for BufferEditPrediction<'_> {
 417    type Target = EditPrediction;
 418
 419    fn deref(&self) -> &Self::Target {
 420        match self {
 421            BufferEditPrediction::Local { prediction } => prediction,
 422            BufferEditPrediction::Jump { prediction } => prediction,
 423        }
 424    }
 425}
 426
/// Tracking state for a buffer registered with a `ZetaProject`.
struct RegisteredBuffer {
    // Last snapshot seen; compared against new snapshots to build diffs.
    snapshot: BufferSnapshot,
    // Edit subscription + release observer (see `register_buffer_impl`).
    _subscriptions: [gpui::Subscription; 2],
}

/// The most recent, still-open buffer change; may absorb subsequent nearby
/// edits before being finalized into an event.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    // End of the last edit, used to decide whether to coalesce the next one.
    end_edit_anchor: Option<Anchor>,
}
 437
 438impl LastEvent {
 439    pub fn finalize(
 440        &self,
 441        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 442        cx: &App,
 443    ) -> Option<Arc<predict_edits_v3::Event>> {
 444        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
 445        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
 446
 447        let file = self.new_snapshot.file();
 448        let old_file = self.old_snapshot.file();
 449
 450        let in_open_source_repo = [file, old_file].iter().all(|file| {
 451            file.is_some_and(|file| {
 452                license_detection_watchers
 453                    .get(&file.worktree_id(cx))
 454                    .is_some_and(|watcher| watcher.is_project_open_source())
 455            })
 456        });
 457
 458        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
 459
 460        if path == old_path && diff.is_empty() {
 461            None
 462        } else {
 463            Some(Arc::new(predict_edits_v3::Event::BufferChange {
 464                old_path,
 465                path,
 466                diff,
 467                in_open_source_repo,
 468                // TODO: Actually detect if this edit was predicted or not
 469                predicted: false,
 470            }))
 471        }
 472    }
 473}
 474
 475fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
 476    if let Some(file) = snapshot.file() {
 477        file.full_path(cx).into()
 478    } else {
 479        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 480    }
 481}
 482
 483impl Zeta {
 484    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 485        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 486    }
 487
 488    pub fn global(
 489        client: &Arc<Client>,
 490        user_store: &Entity<UserStore>,
 491        cx: &mut App,
 492    ) -> Entity<Self> {
 493        cx.try_global::<ZetaGlobal>()
 494            .map(|global| global.0.clone())
 495            .unwrap_or_else(|| {
 496                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 497                cx.set_global(ZetaGlobal(zeta.clone()));
 498                zeta
 499            })
 500    }
 501
    /// Creates a `Zeta`, wiring up the LLM-token refresh subscription and a
    /// background task that flushes rejected predictions on demand.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        // Each `()` sent on this channel triggers a flush of buffered
        // prediction rejections to the server.
        let (reject_tx, mut reject_rx) = mpsc::unbounded();
        cx.spawn(async move |this, cx| {
            while let Some(()) = reject_rx.next().await {
                this.update(cx, |this, cx| this.flush_rejected_predictions(cx))?
                    .await
                    .log_err();
            }
            anyhow::Ok(())
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the LLM token whenever the listener signals a refresh.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            rejected_predictions: Vec::new(),
            reject_predictions_debounce_task: None,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
 549
    /// Selects which backend model serves edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Whether a Sweep AI API token is configured.
    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_ai.api_token.is_some()
    }

    /// Installs an eval cache (eval-support builds only).
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 562
    /// Creates a channel receiving debug events for each prediction step,
    /// replacing any previously installed listener.
    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    /// Current prediction pipeline options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prediction pipeline options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    /// Clears the recorded edit-event history for every project.
    pub fn clear_history(&mut self) {
        for zeta_project in self.projects.values_mut() {
            zeta_project.events.clear();
        }
    }
 582
 583    pub fn context_for_project(
 584        &self,
 585        project: &Entity<Project>,
 586    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 587        self.projects
 588            .get(&project.entity_id())
 589            .and_then(|project| {
 590                Some(
 591                    project
 592                        .context
 593                        .as_ref()?
 594                        .iter()
 595                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 596                )
 597            })
 598            .into_iter()
 599            .flatten()
 600    }
 601
 602    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 603        if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
 604            self.user_store.read(cx).edit_prediction_usage()
 605        } else {
 606            None
 607        }
 608    }
 609
    /// Ensures per-project prediction state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking `buffer` within `project` (edit events, license
    /// detection, release cleanup).
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 623
    /// Returns the `ZetaProject` for `project`, creating it on first use.
    /// A syntax index is only built when the current context mode is `Syntax`.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                license_detection_watchers: HashMap::default(),
                // Forward project events (active entry, diagnostics) to Zeta.
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
 656
 657    fn handle_project_event(
 658        &mut self,
 659        project: Entity<Project>,
 660        event: &project::Event,
 661        cx: &mut Context<Self>,
 662    ) {
 663        // TODO [zeta2] init with recent paths
 664        match event {
 665            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 666                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
 667                    return;
 668                };
 669                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 670                if let Some(path) = path {
 671                    if let Some(ix) = zeta_project
 672                        .recent_paths
 673                        .iter()
 674                        .position(|probe| probe == &path)
 675                    {
 676                        zeta_project.recent_paths.remove(ix);
 677                    }
 678                    zeta_project.recent_paths.push_front(path);
 679                }
 680            }
 681            project::Event::DiagnosticsUpdated { .. } => {
 682                if cx.has_flag::<Zeta2FeatureFlag>() {
 683                    self.refresh_prediction_from_diagnostics(project, cx);
 684                }
 685            }
 686            _ => (),
 687        }
 688    }
 689
    /// Registers `buffer` with `zeta_project` and returns its tracking state.
    ///
    /// Also lazily creates a license-detection watcher for the buffer's
    /// worktree, and installs release observers so both the watcher and the
    /// buffer registration are dropped when their entities are released.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        // Drop the watcher when its worktree is released.
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record every edit as a potential prediction event.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Unregister the buffer when it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 749
    /// Records a buffer change as an edit event, coalescing consecutive edits
    /// in the same buffer when they land within `CHANGE_GROUPING_LINE_SPAN`
    /// rows of the previous edit's end point.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // No version change means no edits to record.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Where the most recent edit ended, for the coalescing check below.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // Coalesce only when this change continues directly from the open
            // event's snapshot of the same buffer...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and its end lands close (in rows) to the previous edit's end.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // Extend the open event in place instead of finalizing it.
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Keep the history bounded; the open `last_event` counts as one more.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        // Finalize the previous open event before starting a new one.
        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
 812
 813    fn current_prediction_for_buffer(
 814        &self,
 815        buffer: &Entity<Buffer>,
 816        project: &Entity<Project>,
 817        cx: &App,
 818    ) -> Option<BufferEditPrediction<'_>> {
 819        let project_state = self.projects.get(&project.entity_id())?;
 820
 821        let CurrentEditPrediction {
 822            requested_by,
 823            prediction,
 824            ..
 825        } = project_state.current_prediction.as_ref()?;
 826
 827        if prediction.targets_buffer(buffer.read(cx)) {
 828            Some(BufferEditPrediction::Local { prediction })
 829        } else {
 830            let show_jump = match requested_by {
 831                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
 832                    requested_by_buffer_id == &buffer.entity_id()
 833                }
 834                PredictionRequestedBy::DiagnosticsUpdate => true,
 835            };
 836
 837            if show_jump {
 838                Some(BufferEditPrediction::Jump { prediction })
 839            } else {
 840                None
 841            }
 842        }
 843    }
 844
    /// Accepts the project's current prediction: cancels any pending requests
    /// and notifies the server (Zeta1/Zeta2 only; Sweep has no accept call).
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Accepting supersedes any still-pending prediction requests.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the endpoint for testing.
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 897
 898    fn flush_rejected_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 899        match self.edit_prediction_model {
 900            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
 901            ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
 902        }
 903
 904        let client = self.client.clone();
 905        let llm_token = self.llm_token.clone();
 906        let app_version = AppVersion::global(cx);
 907        let last_rejection = self.rejected_predictions.last().cloned();
 908        let Some(last_rejection) = last_rejection else {
 909            return Task::ready(anyhow::Ok(()));
 910        };
 911
 912        let body = serde_json::to_string(&RejectEditPredictionsBody {
 913            rejections: self.rejected_predictions.clone(),
 914        })
 915        .ok();
 916
 917        cx.spawn(async move |this, cx| {
 918            let url = client
 919                .http_client()
 920                .build_zed_llm_url("/predict_edits/reject", &[])?;
 921
 922            cx.background_spawn(Self::send_api_request::<()>(
 923                move |builder| {
 924                    let req = builder.uri(url.as_ref()).body(body.clone().into());
 925                    Ok(req?)
 926                },
 927                client,
 928                llm_token,
 929                app_version,
 930            ))
 931            .await
 932            .context("Failed to reject edit predictions")?;
 933
 934            this.update(cx, |this, _| {
 935                if let Some(ix) = this
 936                    .rejected_predictions
 937                    .iter()
 938                    .position(|rejection| rejection.request_id == last_rejection.request_id)
 939                {
 940                    this.rejected_predictions.drain(..ix + 1);
 941                }
 942            })
 943        })
 944    }
 945
 946    fn reject_current_prediction(
 947        &mut self,
 948        reason: EditPredictionRejectReason,
 949        project: &Entity<Project>,
 950        cx: &mut Context<Self>,
 951    ) {
 952        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 953            project_state.pending_predictions.clear();
 954            if let Some(prediction) = project_state.current_prediction.take() {
 955                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
 956            }
 957        };
 958    }
 959
 960    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
 961        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 962            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
 963                if !current_prediction.was_shown {
 964                    current_prediction.was_shown = true;
 965                    self.shown_predictions
 966                        .push_front(current_prediction.prediction.clone());
 967                    if self.shown_predictions.len() > 50 {
 968                        let completion = self.shown_predictions.pop_back().unwrap();
 969                        self.rated_predictions.remove(&completion.id);
 970                    }
 971                }
 972            }
 973        }
 974    }
 975
 976    fn reject_prediction(
 977        &mut self,
 978        prediction_id: EditPredictionId,
 979        reason: EditPredictionRejectReason,
 980        was_shown: bool,
 981        cx: &mut Context<Self>,
 982    ) {
 983        self.rejected_predictions.push(EditPredictionRejection {
 984            request_id: prediction_id.to_string(),
 985            reason,
 986            was_shown,
 987        });
 988
 989        let reached_request_limit =
 990            self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
 991        let reject_tx = self.reject_predictions_tx.clone();
 992        self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
 993            const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 994            if !reached_request_limit {
 995                cx.background_executor()
 996                    .timer(REJECT_REQUEST_DEBOUNCE)
 997                    .await;
 998            }
 999            reject_tx.unbounded_send(()).log_err();
1000        }));
1001    }
1002
1003    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1004        self.projects
1005            .get(&project.entity_id())
1006            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1007    }
1008
1009    pub fn refresh_prediction_from_buffer(
1010        &mut self,
1011        project: Entity<Project>,
1012        buffer: Entity<Buffer>,
1013        position: language::Anchor,
1014        cx: &mut Context<Self>,
1015    ) {
1016        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1017            let Some(request_task) = this
1018                .update(cx, |this, cx| {
1019                    this.request_prediction(&project, &buffer, position, cx)
1020                })
1021                .log_err()
1022            else {
1023                return Task::ready(anyhow::Ok(None));
1024            };
1025
1026            cx.spawn(async move |_cx| {
1027                request_task.await.map(|prediction_result| {
1028                    prediction_result.map(|prediction_result| {
1029                        (
1030                            prediction_result,
1031                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1032                        )
1033                    })
1034                })
1035            })
1036        })
1037    }
1038
    /// Requests a prediction triggered by a diagnostics update, targeting the
    /// diagnostic location nearest the active entry's buffer.
    ///
    /// Buffer-initiated predictions take precedence: this is a no-op when a
    /// current prediction already exists, and the result is recorded as
    /// rejected if one appears while this request is in flight.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        // Throttle on the project entity, since no single buffer owns this
        // refresh.
        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Default (empty) search range and cursor: there is no prior
                // request to exclude, so search from the top of the buffer.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(&project, &jump_buffer, jump_position, cx)
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // Install the result only if no buffer-requested prediction
                // arrived meanwhile; otherwise mark it rejected in favor of
                // the current one.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_zeta_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1113
    /// Minimum delay between consecutive prediction refreshes triggered by the
    /// same entity (see `queue_prediction_refresh`); zero under `cfg(test)` so
    /// tests run without waiting.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1118
    /// Enqueues a prediction refresh produced by `do_refresh`, throttling
    /// consecutive refreshes from the same `throttle_entity` and keeping at
    /// most two requests pending per project — queueing a third replaces the
    /// younger pending request.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        let zeta_project = self.get_or_init_zeta_project(&project, cx);
        let pending_prediction_id = zeta_project.next_pending_prediction_id;
        zeta_project.next_pending_prediction_id += 1;
        let last_request = zeta_project.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Wait out the remainder of THROTTLE_TIMEOUT since the last
            // refresh for this same entity.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_zeta_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let zeta_project = this.get_or_init_zeta_project(&project, cx);

                // The request may also have been cancelled while in flight.
                let is_cancelled = zeta_project
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                            };

                            // Keep whichever of the old and new predictions is
                            // preferred, reporting the loser as rejected.
                            if let Some(current_prediction) =
                                zeta_project.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                        cx,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                        cx,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false, cx);
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-fetch: the reject_* calls above needed `this` mutably.
                let zeta_project = this.get_or_init_zeta_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    zeta_project.current_prediction = Some(new_prediction);
                }

                // Remove this request from the pending list and cancel any
                // older requests still queued ahead of it.
                let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            zeta_project.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_zeta_project(&project, cx)
                    .pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Keep at most two pending requests: the oldest may already be
        // mid-flight, so a third replaces the newest pending one.
        if zeta_project.pending_predictions.len() <= 1 {
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if zeta_project.pending_predictions.len() == 2 {
            let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            zeta_project.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1261
1262    pub fn request_prediction(
1263        &mut self,
1264        project: &Entity<Project>,
1265        active_buffer: &Entity<Buffer>,
1266        position: language::Anchor,
1267        cx: &mut Context<Self>,
1268    ) -> Task<Result<Option<EditPredictionResult>>> {
1269        self.request_prediction_internal(
1270            project.clone(),
1271            active_buffer.clone(),
1272            position,
1273            cx.has_flag::<Zeta2FeatureFlag>(),
1274            cx,
1275        )
1276    }
1277
    /// Core prediction request: dispatches to the configured model and, when
    /// `allow_jump` is set and no prediction was produced at the cursor,
    /// retries once at the nearest diagnostic outside the already-searched
    /// range. The retry passes `allow_jump = false`, so it cannot recurse.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Diagnostics within this many rows of the cursor are considered part
        // of the current request's context.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                // Only jump when there were recent edit events to base a
                // prediction on.
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1365
    /// Finds the next diagnostic to offer as a jump target: first the
    /// diagnostic in the active buffer closest to the cursor but outside the
    /// range already covered by the last request; failing that, a diagnostic
    /// in the project buffer whose path shares the most parent directories
    /// with the active buffer's path.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            // Closest by row distance to the cursor.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    // Never jump back into the buffer we just searched.
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            // NOTE(review): picks the minimum severity value —
                            // assumes lower values are more severe (LSP
                            // convention: 1 = Error); confirm against the
                            // DiagnosticSeverity ordering in use.
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1444
1445    fn request_prediction_with_zeta2(
1446        &mut self,
1447        project: &Entity<Project>,
1448        active_buffer: &Entity<Buffer>,
1449        active_snapshot: BufferSnapshot,
1450        position: language::Anchor,
1451        events: Vec<Arc<Event>>,
1452        cx: &mut Context<Self>,
1453    ) -> Task<Result<Option<EditPredictionResult>>> {
1454        let project_state = self.projects.get(&project.entity_id());
1455
1456        let index_state = project_state.and_then(|state| {
1457            state
1458                .syntax_index
1459                .as_ref()
1460                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
1461        });
1462        let options = self.options.clone();
1463        let buffer_snapshotted_at = Instant::now();
1464        let Some(excerpt_path) = active_snapshot
1465            .file()
1466            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
1467        else {
1468            return Task::ready(Err(anyhow!("No file path for excerpt")));
1469        };
1470        let client = self.client.clone();
1471        let llm_token = self.llm_token.clone();
1472        let app_version = AppVersion::global(cx);
1473        let worktree_snapshots = project
1474            .read(cx)
1475            .worktrees(cx)
1476            .map(|worktree| worktree.read(cx).snapshot())
1477            .collect::<Vec<_>>();
1478        let debug_tx = self.debug_tx.clone();
1479
1480        let diagnostics = active_snapshot.diagnostic_sets().clone();
1481
1482        let file = active_buffer.read(cx).file();
1483        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
1484            let mut path = f.worktree.read(cx).absolutize(&f.path);
1485            if path.pop() { Some(path) } else { None }
1486        });
1487
1488        // TODO data collection
1489        let can_collect_data = file
1490            .as_ref()
1491            .map_or(false, |file| self.can_collect_file(project, file, cx));
1492
1493        let empty_context_files = HashMap::default();
1494        let context_files = project_state
1495            .and_then(|project_state| project_state.context.as_ref())
1496            .unwrap_or(&empty_context_files);
1497
1498        #[cfg(feature = "eval-support")]
1499        let parsed_fut = futures::future::join_all(
1500            context_files
1501                .keys()
1502                .map(|buffer| buffer.read(cx).parsing_idle()),
1503        );
1504
1505        let mut included_files = context_files
1506            .iter()
1507            .filter_map(|(buffer_entity, ranges)| {
1508                let buffer = buffer_entity.read(cx);
1509                Some((
1510                    buffer_entity.clone(),
1511                    buffer.snapshot(),
1512                    buffer.file()?.full_path(cx).into(),
1513                    ranges.clone(),
1514                ))
1515            })
1516            .collect::<Vec<_>>();
1517
1518        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
1519            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
1520        });
1521
1522        #[cfg(feature = "eval-support")]
1523        let eval_cache = self.eval_cache.clone();
1524
1525        let request_task = cx.background_spawn({
1526            let active_buffer = active_buffer.clone();
1527            async move {
1528                #[cfg(feature = "eval-support")]
1529                parsed_fut.await;
1530
1531                let index_state = if let Some(index_state) = index_state {
1532                    Some(index_state.lock_owned().await)
1533                } else {
1534                    None
1535                };
1536
1537                let cursor_offset = position.to_offset(&active_snapshot);
1538                let cursor_point = cursor_offset.to_point(&active_snapshot);
1539
1540                let before_retrieval = Instant::now();
1541
1542                let (diagnostic_groups, diagnostic_groups_truncated) =
1543                    Self::gather_nearby_diagnostics(
1544                        cursor_offset,
1545                        &diagnostics,
1546                        &active_snapshot,
1547                        options.max_diagnostic_bytes,
1548                    );
1549
1550                let cloud_request = match options.context {
1551                    ContextMode::Agentic(context_options) => {
1552                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
1553                            cursor_point,
1554                            &active_snapshot,
1555                            &context_options.excerpt,
1556                            index_state.as_deref(),
1557                        ) else {
1558                            return Ok((None, None));
1559                        };
1560
1561                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
1562                            ..active_snapshot.anchor_before(excerpt.range.end);
1563
1564                        if let Some(buffer_ix) =
1565                            included_files.iter().position(|(_, snapshot, _, _)| {
1566                                snapshot.remote_id() == active_snapshot.remote_id()
1567                            })
1568                        {
1569                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
1570                            ranges.push(excerpt_anchor_range);
1571                            retrieval_search::merge_anchor_ranges(ranges, buffer);
1572                            let last_ix = included_files.len() - 1;
1573                            included_files.swap(buffer_ix, last_ix);
1574                        } else {
1575                            included_files.push((
1576                                active_buffer.clone(),
1577                                active_snapshot.clone(),
1578                                excerpt_path.clone(),
1579                                vec![excerpt_anchor_range],
1580                            ));
1581                        }
1582
1583                        let included_files = included_files
1584                            .iter()
1585                            .map(|(_, snapshot, path, ranges)| {
1586                                let ranges = ranges
1587                                    .iter()
1588                                    .map(|range| {
1589                                        let point_range = range.to_point(&snapshot);
1590                                        Line(point_range.start.row)..Line(point_range.end.row)
1591                                    })
1592                                    .collect::<Vec<_>>();
1593                                let excerpts = assemble_excerpts(&snapshot, ranges);
1594                                predict_edits_v3::IncludedFile {
1595                                    path: path.clone(),
1596                                    max_row: Line(snapshot.max_point().row),
1597                                    excerpts,
1598                                }
1599                            })
1600                            .collect::<Vec<_>>();
1601
1602                        predict_edits_v3::PredictEditsRequest {
1603                            excerpt_path,
1604                            excerpt: String::new(),
1605                            excerpt_line_range: Line(0)..Line(0),
1606                            excerpt_range: 0..0,
1607                            cursor_point: predict_edits_v3::Point {
1608                                line: predict_edits_v3::Line(cursor_point.row),
1609                                column: cursor_point.column,
1610                            },
1611                            included_files,
1612                            referenced_declarations: vec![],
1613                            events,
1614                            can_collect_data,
1615                            diagnostic_groups,
1616                            diagnostic_groups_truncated,
1617                            debug_info: debug_tx.is_some(),
1618                            prompt_max_bytes: Some(options.max_prompt_bytes),
1619                            prompt_format: options.prompt_format,
1620                            // TODO [zeta2]
1621                            signatures: vec![],
1622                            excerpt_parent: None,
1623                            git_info: None,
1624                        }
1625                    }
1626                    ContextMode::Syntax(context_options) => {
1627                        let Some(context) = EditPredictionContext::gather_context(
1628                            cursor_point,
1629                            &active_snapshot,
1630                            parent_abs_path.as_deref(),
1631                            &context_options,
1632                            index_state.as_deref(),
1633                        ) else {
1634                            return Ok((None, None));
1635                        };
1636
1637                        make_syntax_context_cloud_request(
1638                            excerpt_path,
1639                            context,
1640                            events,
1641                            can_collect_data,
1642                            diagnostic_groups,
1643                            diagnostic_groups_truncated,
1644                            None,
1645                            debug_tx.is_some(),
1646                            &worktree_snapshots,
1647                            index_state.as_deref(),
1648                            Some(options.max_prompt_bytes),
1649                            options.prompt_format,
1650                        )
1651                    }
1652                };
1653
1654                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
1655
1656                let inputs = EditPredictionInputs {
1657                    included_files: cloud_request.included_files,
1658                    events: cloud_request.events,
1659                    cursor_point: cloud_request.cursor_point,
1660                    cursor_path: cloud_request.excerpt_path,
1661                };
1662
1663                let retrieval_time = Instant::now() - before_retrieval;
1664
1665                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
1666                    let (response_tx, response_rx) = oneshot::channel();
1667
1668                    debug_tx
1669                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
1670                            ZetaEditPredictionDebugInfo {
1671                                inputs: inputs.clone(),
1672                                retrieval_time,
1673                                buffer: active_buffer.downgrade(),
1674                                local_prompt: match prompt_result.as_ref() {
1675                                    Ok((prompt, _)) => Ok(prompt.clone()),
1676                                    Err(err) => Err(err.to_string()),
1677                                },
1678                                position,
1679                                response_rx,
1680                            },
1681                        ))
1682                        .ok();
1683                    Some(response_tx)
1684                } else {
1685                    None
1686                };
1687
1688                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
1689                    if let Some(debug_response_tx) = debug_response_tx {
1690                        debug_response_tx
1691                            .send((Err("Request skipped".to_string()), Duration::ZERO))
1692                            .ok();
1693                    }
1694                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
1695                }
1696
1697                let (prompt, _) = prompt_result?;
1698                let generation_params =
1699                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
1700                let request = open_ai::Request {
1701                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1702                    messages: vec![open_ai::RequestMessage::User {
1703                        content: open_ai::MessageContent::Plain(prompt),
1704                    }],
1705                    stream: false,
1706                    max_completion_tokens: None,
1707                    stop: generation_params.stop.unwrap_or_default(),
1708                    temperature: generation_params.temperature.unwrap_or(0.7),
1709                    tool_choice: None,
1710                    parallel_tool_calls: None,
1711                    tools: vec![],
1712                    prompt_cache_key: None,
1713                    reasoning_effort: None,
1714                };
1715
1716                log::trace!("Sending edit prediction request");
1717
1718                let before_request = Instant::now();
1719                let response = Self::send_raw_llm_request(
1720                    request,
1721                    client,
1722                    llm_token,
1723                    app_version,
1724                    #[cfg(feature = "eval-support")]
1725                    eval_cache,
1726                    #[cfg(feature = "eval-support")]
1727                    EvalCacheEntryKind::Prediction,
1728                )
1729                .await;
1730                let received_response_at = Instant::now();
1731                let request_time = received_response_at - before_request;
1732
1733                log::trace!("Got edit prediction response");
1734
1735                if let Some(debug_response_tx) = debug_response_tx {
1736                    debug_response_tx
1737                        .send((
1738                            response
1739                                .as_ref()
1740                                .map_err(|err| err.to_string())
1741                                .map(|response| response.0.clone()),
1742                            request_time,
1743                        ))
1744                        .ok();
1745                }
1746
1747                let (res, usage) = response?;
1748                let request_id = EditPredictionId(res.id.clone().into());
1749                let Some(mut output_text) = text_from_response(res) else {
1750                    return Ok((Some((request_id, None)), usage));
1751                };
1752
1753                if output_text.contains(CURSOR_MARKER) {
1754                    log::trace!("Stripping out {CURSOR_MARKER} from response");
1755                    output_text = output_text.replace(CURSOR_MARKER, "");
1756                }
1757
1758                let get_buffer_from_context = |path: &Path| {
1759                    included_files
1760                        .iter()
1761                        .find_map(|(_, buffer, probe_path, ranges)| {
1762                            if probe_path.as_ref() == path {
1763                                Some((buffer, ranges.as_slice()))
1764                            } else {
1765                                None
1766                            }
1767                        })
1768                };
1769
1770                let (edited_buffer_snapshot, edits) = match options.prompt_format {
1771                    PromptFormat::NumLinesUniDiff => {
1772                        // TODO: Implement parsing of multi-file diffs
1773                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1774                    }
1775                    PromptFormat::Minimal
1776                    | PromptFormat::MinimalQwen
1777                    | PromptFormat::SeedCoder1120 => {
1778                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
1779                            let edits = vec![];
1780                            (&active_snapshot, edits)
1781                        } else {
1782                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1783                        }
1784                    }
1785                    PromptFormat::OldTextNewText => {
1786                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
1787                            .await?
1788                    }
1789                    _ => {
1790                        bail!("unsupported prompt format {}", options.prompt_format)
1791                    }
1792                };
1793
1794                let edited_buffer = included_files
1795                    .iter()
1796                    .find_map(|(buffer, snapshot, _, _)| {
1797                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1798                            Some(buffer.clone())
1799                        } else {
1800                            None
1801                        }
1802                    })
1803                    .context("Failed to find buffer in included_buffers")?;
1804
1805                anyhow::Ok((
1806                    Some((
1807                        request_id,
1808                        Some((
1809                            inputs,
1810                            edited_buffer,
1811                            edited_buffer_snapshot.clone(),
1812                            edits,
1813                            received_response_at,
1814                        )),
1815                    )),
1816                    usage,
1817                ))
1818            }
1819        });
1820
1821        cx.spawn({
1822            async move |this, cx| {
1823                let Some((id, prediction)) =
1824                    Self::handle_api_response(&this, request_task.await, cx)?
1825                else {
1826                    return Ok(None);
1827                };
1828
1829                let Some((
1830                    inputs,
1831                    edited_buffer,
1832                    edited_buffer_snapshot,
1833                    edits,
1834                    received_response_at,
1835                )) = prediction
1836                else {
1837                    return Ok(Some(EditPredictionResult {
1838                        id,
1839                        prediction: Err(EditPredictionRejectReason::Empty),
1840                    }));
1841                };
1842
1843                // TODO telemetry: duration, etc
1844                Ok(Some(
1845                    EditPredictionResult::new(
1846                        id,
1847                        &edited_buffer,
1848                        &edited_buffer_snapshot,
1849                        edits.into(),
1850                        buffer_snapshotted_at,
1851                        received_response_at,
1852                        inputs,
1853                        cx,
1854                    )
1855                    .await,
1856                ))
1857            }
1858        })
1859    }
1860
    /// Sends `request` to the edit-prediction endpoint — `PREDICT_EDITS_URL`
    /// when that override is set, otherwise the Zed LLM `/predict_edits/raw`
    /// route — and returns the deserialized response plus any usage info
    /// reported by the server.
    ///
    /// With the `eval-support` feature enabled, responses are memoized in
    /// `eval_cache`, keyed by `(eval_cache_kind, hash)` where the hash covers
    /// the URL and the pretty-printed request JSON. A cache hit short-circuits
    /// the network call and returns `None` for usage.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        // Environment override takes precedence over the client-built URL.
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            // The cache key hashes both the destination URL and the exact
            // pretty-printed request body, so any change to either produces
            // a distinct cache entry.
            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                // Cache hit: no network round trip, so no usage data.
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            // Miss: remember what to write back after a successful request.
            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Populate the eval cache only after a successful round trip.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1918
1919    fn handle_api_response<T>(
1920        this: &WeakEntity<Self>,
1921        response: Result<(T, Option<EditPredictionUsage>)>,
1922        cx: &mut gpui::AsyncApp,
1923    ) -> Result<T> {
1924        match response {
1925            Ok((data, usage)) => {
1926                if let Some(usage) = usage {
1927                    this.update(cx, |this, cx| {
1928                        this.user_store.update(cx, |user_store, cx| {
1929                            user_store.update_edit_prediction_usage(usage, cx);
1930                        });
1931                    })
1932                    .ok();
1933                }
1934                Ok(data)
1935            }
1936            Err(err) => {
1937                if err.is::<ZedUpdateRequiredError>() {
1938                    cx.update(|cx| {
1939                        this.update(cx, |this, _cx| {
1940                            this.update_required = true;
1941                        })
1942                        .ok();
1943
1944                        let error_message: SharedString = err.to_string().into();
1945                        show_app_notification(
1946                            NotificationId::unique::<ZedUpdateRequiredError>(),
1947                            cx,
1948                            move |cx| {
1949                                cx.new(|cx| {
1950                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1951                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1952                                })
1953                            },
1954                        );
1955                    })
1956                    .ok();
1957                }
1958                Err(err)
1959            }
1960        }
1961    }
1962
    /// Sends an authenticated POST built by `build`, deserializing a
    /// successful JSON response into `Res` along with any usage headers.
    ///
    /// Retries exactly once with a freshly refreshed LLM token when the
    /// server responds with the expired-token header; any other non-success
    /// response fails with the status and body text.
    ///
    /// # Errors
    /// Returns `ZedUpdateRequiredError` when the server's minimum-required
    /// version header exceeds `app_version` — this is checked on every
    /// response, regardless of status code.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            // Rebuild the request each attempt so a refreshed token is picked
            // up by the Authorization header on retry.
            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Version gate: bail out with ZedUpdateRequiredError before
            // looking at the status code at all.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh once and loop to retry the request.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                // Any other failure (or a second expired-token response):
                // report the status and body.
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2026
    /// How long without a context refresh before the user counts as idle; the
    /// first edit after this window triggers an immediate (undebounced) refresh.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied to context refreshes during an active editing burst.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
2029
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context retrieval only applies to the agentic context mode.
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // "Idle" means no refresh trigger within the idle window (or none ever).
        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Overwriting the debounce task drops (cancels) any previously queued
        // refresh, so only the latest trigger survives.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    // First trigger after idle: refresh immediately.
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    // Mid-burst trigger: wait for a pause in editing first.
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Hold the task on the project so it runs to completion.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
2077
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Unknown project or non-agentic context mode: nothing to refresh.
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        // Select the excerpt surrounding the cursor that seeds the retrieval
        // planning prompt; bail quietly if none can be selected.
        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the search-planning prompt from the cursor excerpt plus the
        // project's recent edit events.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project.events(cx),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Lazily-built (schema, description) for the search tool, computed once
        // per process and cloned per request.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        // Ask the retrieval model to plan searches by calling the search tool.
        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            // Expect an assistant message whose tool calls carry the planned
            // search queries.
            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect queries from every call to the expected search tool;
            // calls to any other tool name are logged and skipped.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            // Execute the planned searches against the project.
            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Store the results (or propagate the error) and clear the
            // in-flight task handle so a new refresh can be spawned.
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
2289
2290    pub fn set_context(
2291        &mut self,
2292        project: Entity<Project>,
2293        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2294    ) {
2295        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2296            zeta_project.context = Some(context);
2297        }
2298    }
2299
2300    fn gather_nearby_diagnostics(
2301        cursor_offset: usize,
2302        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
2303        snapshot: &BufferSnapshot,
2304        max_diagnostics_bytes: usize,
2305    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
2306        // TODO: Could make this more efficient
2307        let mut diagnostic_groups = Vec::new();
2308        for (language_server_id, diagnostics) in diagnostic_sets {
2309            let mut groups = Vec::new();
2310            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
2311            diagnostic_groups.extend(
2312                groups
2313                    .into_iter()
2314                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
2315            );
2316        }
2317
2318        // sort by proximity to cursor
2319        diagnostic_groups.sort_by_key(|group| {
2320            let range = &group.entries[group.primary_ix].range;
2321            if range.start >= cursor_offset {
2322                range.start - cursor_offset
2323            } else if cursor_offset >= range.end {
2324                cursor_offset - range.end
2325            } else {
2326                (cursor_offset - range.start).min(range.end - cursor_offset)
2327            }
2328        });
2329
2330        let mut results = Vec::new();
2331        let mut diagnostic_groups_truncated = false;
2332        let mut diagnostics_byte_count = 0;
2333        for group in diagnostic_groups {
2334            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
2335            diagnostics_byte_count += raw_value.get().len();
2336            if diagnostics_byte_count > max_diagnostics_bytes {
2337                diagnostic_groups_truncated = true;
2338                break;
2339            }
2340            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
2341        }
2342
2343        (results, diagnostic_groups_truncated)
2344    }
2345
2346    // TODO: Dedupe with similar code in request_prediction?
2347    pub fn cloud_request_for_zeta_cli(
2348        &mut self,
2349        project: &Entity<Project>,
2350        buffer: &Entity<Buffer>,
2351        position: language::Anchor,
2352        cx: &mut Context<Self>,
2353    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
2354        let project_state = self.projects.get(&project.entity_id());
2355
2356        let index_state = project_state.and_then(|state| {
2357            state
2358                .syntax_index
2359                .as_ref()
2360                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
2361        });
2362        let options = self.options.clone();
2363        let snapshot = buffer.read(cx).snapshot();
2364        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
2365            return Task::ready(Err(anyhow!("No file path for excerpt")));
2366        };
2367        let worktree_snapshots = project
2368            .read(cx)
2369            .worktrees(cx)
2370            .map(|worktree| worktree.read(cx).snapshot())
2371            .collect::<Vec<_>>();
2372
2373        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
2374            let mut path = f.worktree.read(cx).absolutize(&f.path);
2375            if path.pop() { Some(path) } else { None }
2376        });
2377
2378        cx.background_spawn(async move {
2379            let index_state = if let Some(index_state) = index_state {
2380                Some(index_state.lock_owned().await)
2381            } else {
2382                None
2383            };
2384
2385            let cursor_point = position.to_point(&snapshot);
2386
2387            let debug_info = true;
2388            EditPredictionContext::gather_context(
2389                cursor_point,
2390                &snapshot,
2391                parent_abs_path.as_deref(),
2392                match &options.context {
2393                    ContextMode::Agentic(_) => {
2394                        // TODO
2395                        panic!("Llm mode not supported in zeta cli yet");
2396                    }
2397                    ContextMode::Syntax(edit_prediction_context_options) => {
2398                        edit_prediction_context_options
2399                    }
2400                },
2401                index_state.as_deref(),
2402            )
2403            .context("Failed to select excerpt")
2404            .map(|context| {
2405                make_syntax_context_cloud_request(
2406                    excerpt_path.into(),
2407                    context,
2408                    // TODO pass everything
2409                    Vec::new(),
2410                    false,
2411                    Vec::new(),
2412                    false,
2413                    None,
2414                    debug_info,
2415                    &worktree_snapshots,
2416                    index_state.as_deref(),
2417                    Some(options.max_prompt_bytes),
2418                    options.prompt_format,
2419                )
2420            })
2421        })
2422    }
2423
2424    pub fn wait_for_initial_indexing(
2425        &mut self,
2426        project: &Entity<Project>,
2427        cx: &mut Context<Self>,
2428    ) -> Task<Result<()>> {
2429        let zeta_project = self.get_or_init_zeta_project(project, cx);
2430        if let Some(syntax_index) = &zeta_project.syntax_index {
2431            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2432        } else {
2433            Task::ready(Ok(()))
2434        }
2435    }
2436
2437    fn is_file_open_source(
2438        &self,
2439        project: &Entity<Project>,
2440        file: &Arc<dyn File>,
2441        cx: &App,
2442    ) -> bool {
2443        if !file.is_local() || file.is_private() {
2444            return false;
2445        }
2446        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2447            return false;
2448        };
2449        zeta_project
2450            .license_detection_watchers
2451            .get(&file.worktree_id(cx))
2452            .as_ref()
2453            .is_some_and(|watcher| watcher.is_project_open_source())
2454    }
2455
2456    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2457        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2458    }
2459
2460    fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2461        if !self.data_collection_choice.is_enabled() {
2462            return false;
2463        }
2464        events.iter().all(|event| {
2465            matches!(
2466                event.as_ref(),
2467                Event::BufferChange {
2468                    in_open_source_repo: true,
2469                    ..
2470                }
2471            )
2472        })
2473    }
2474
2475    fn load_data_collection_choice() -> DataCollectionChoice {
2476        let choice = KEY_VALUE_STORE
2477            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2478            .log_err()
2479            .flatten();
2480
2481        match choice.as_deref() {
2482            Some("true") => DataCollectionChoice::Enabled,
2483            Some("false") => DataCollectionChoice::Disabled,
2484            Some(_) => {
2485                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2486                DataCollectionChoice::NotAnswered
2487            }
2488            None => DataCollectionChoice::NotAnswered,
2489        }
2490    }
2491
    /// Iterator over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2495
    /// Number of predictions that have been shown to the user.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2499
    /// Whether the user has already rated the prediction with the given id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2503
    /// Records a user rating for `prediction` and reports it via telemetry,
    /// including the prediction's inputs, its output as a unified diff, and
    /// the user's freeform `feedback`.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        // Remember the rating so the UI can mark this prediction as rated.
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Flush immediately so the rating isn't lost if Zed quits soon after.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2522}
2523
2524pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2525    let choice = res.choices.pop()?;
2526    let output_text = match choice.message {
2527        open_ai::RequestMessage::Assistant {
2528            content: Some(open_ai::MessageContent::Plain(content)),
2529            ..
2530        } => content,
2531        open_ai::RequestMessage::Assistant {
2532            content: Some(open_ai::MessageContent::Multipart(mut content)),
2533            ..
2534        } => {
2535            if content.is_empty() {
2536                log::error!("No output from Baseten completion response");
2537                return None;
2538            }
2539
2540            match content.remove(0) {
2541                open_ai::MessagePart::Text { text } => text,
2542                open_ai::MessagePart::Image { .. } => {
2543                    log::error!("Expected text, got an image");
2544                    return None;
2545                }
2546            }
2547        }
2548        _ => {
2549            log::error!("Invalid response message: {:?}", choice.message);
2550            return None;
2551        }
2552    };
2553    Some(output_text)
2554}
2555
/// Error surfaced when the prediction service requires a newer Zed build.
/// `minimum_version` is presumably populated from the
/// `MINIMUM_REQUIRED_VERSION_HEADER_NAME` response header — confirm at the
/// construction site.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2563
2564fn make_syntax_context_cloud_request(
2565    excerpt_path: Arc<Path>,
2566    context: EditPredictionContext,
2567    events: Vec<Arc<predict_edits_v3::Event>>,
2568    can_collect_data: bool,
2569    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2570    diagnostic_groups_truncated: bool,
2571    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2572    debug_info: bool,
2573    worktrees: &Vec<worktree::Snapshot>,
2574    index_state: Option<&SyntaxIndexState>,
2575    prompt_max_bytes: Option<usize>,
2576    prompt_format: PromptFormat,
2577) -> predict_edits_v3::PredictEditsRequest {
2578    let mut signatures = Vec::new();
2579    let mut declaration_to_signature_index = HashMap::default();
2580    let mut referenced_declarations = Vec::new();
2581
2582    for snippet in context.declarations {
2583        let project_entry_id = snippet.declaration.project_entry_id();
2584        let Some(path) = worktrees.iter().find_map(|worktree| {
2585            worktree.entry_for_id(project_entry_id).map(|entry| {
2586                let mut full_path = RelPathBuf::new();
2587                full_path.push(worktree.root_name());
2588                full_path.push(&entry.path);
2589                full_path
2590            })
2591        }) else {
2592            continue;
2593        };
2594
2595        let parent_index = index_state.and_then(|index_state| {
2596            snippet.declaration.parent().and_then(|parent| {
2597                add_signature(
2598                    parent,
2599                    &mut declaration_to_signature_index,
2600                    &mut signatures,
2601                    index_state,
2602                )
2603            })
2604        });
2605
2606        let (text, text_is_truncated) = snippet.declaration.item_text();
2607        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2608            path: path.as_std_path().into(),
2609            text: text.into(),
2610            range: snippet.declaration.item_line_range(),
2611            text_is_truncated,
2612            signature_range: snippet.declaration.signature_range_in_item_text(),
2613            parent_index,
2614            signature_score: snippet.score(DeclarationStyle::Signature),
2615            declaration_score: snippet.score(DeclarationStyle::Declaration),
2616            score_components: snippet.components,
2617        });
2618    }
2619
2620    let excerpt_parent = index_state.and_then(|index_state| {
2621        context
2622            .excerpt
2623            .parent_declarations
2624            .last()
2625            .and_then(|(parent, _)| {
2626                add_signature(
2627                    *parent,
2628                    &mut declaration_to_signature_index,
2629                    &mut signatures,
2630                    index_state,
2631                )
2632            })
2633    });
2634
2635    predict_edits_v3::PredictEditsRequest {
2636        excerpt_path,
2637        excerpt: context.excerpt_text.body,
2638        excerpt_line_range: context.excerpt.line_range,
2639        excerpt_range: context.excerpt.range,
2640        cursor_point: predict_edits_v3::Point {
2641            line: predict_edits_v3::Line(context.cursor_point.row),
2642            column: context.cursor_point.column,
2643        },
2644        referenced_declarations,
2645        included_files: vec![],
2646        signatures,
2647        excerpt_parent,
2648        events,
2649        can_collect_data,
2650        diagnostic_groups,
2651        diagnostic_groups_truncated,
2652        git_info,
2653        debug_info,
2654        prompt_max_bytes,
2655        prompt_format,
2656    }
2657}
2658
2659fn add_signature(
2660    declaration_id: DeclarationId,
2661    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2662    signatures: &mut Vec<Signature>,
2663    index: &SyntaxIndexState,
2664) -> Option<usize> {
2665    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2666        return Some(*signature_index);
2667    }
2668    let Some(parent_declaration) = index.declaration(declaration_id) else {
2669        log::error!("bug: missing parent declaration");
2670        return None;
2671    };
2672    let parent_index = parent_declaration.parent().and_then(|parent| {
2673        add_signature(parent, declaration_to_signature_index, signatures, index)
2674    });
2675    let (text, text_is_truncated) = parent_declaration.signature_text();
2676    let signature_index = signatures.len();
2677    signatures.push(Signature {
2678        text: text.into(),
2679        text_is_truncated,
2680        parent_index,
2681        range: parent_declaration.signature_line_range(),
2682    });
2683    declaration_to_signature_index.insert(declaration_id, signature_index);
2684    Some(signature_index)
2685}
2686
/// Cache key for eval runs: the kind of cached entry plus a 64-bit hash of
/// its input (hashing scheme not visible here — see `EvalCache` impls).
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2689
/// The stage of the prediction pipeline an eval-cache entry belongs to.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    /// Cached context-gathering result.
    Context,
    /// Cached search (retrieval) result.
    Search,
    /// Cached model prediction.
    Prediction,
}
2697
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Lowercase stage name, used e.g. in cache-entry naming.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        })
    }
}
2708
/// Key/value store used to memoize pipeline stages during evals.
/// `write` receives both the raw `input` and the produced `value`,
/// presumably so implementations can persist the pair for inspection —
/// confirm against implementors.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached value for `key`, if any.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` (computed from `input`) under `key`.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2714
/// The user's answer to the edit-prediction data-collection prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// The user has not been asked, or has not answered yet.
    NotAnswered,
    /// The user opted in to data collection.
    Enabled,
    /// The user opted out of data collection.
    Disabled,
}
2721
2722impl DataCollectionChoice {
2723    pub fn is_enabled(self) -> bool {
2724        match self {
2725            Self::Enabled => true,
2726            Self::NotAnswered | Self::Disabled => false,
2727        }
2728    }
2729
2730    pub fn is_answered(self) -> bool {
2731        match self {
2732            Self::Enabled | Self::Disabled => true,
2733            Self::NotAnswered => false,
2734        }
2735    }
2736
2737    #[must_use]
2738    pub fn toggle(&self) -> DataCollectionChoice {
2739        match self {
2740            Self::Enabled => Self::Disabled,
2741            Self::Disabled => Self::Enabled,
2742            Self::NotAnswered => Self::Enabled,
2743        }
2744    }
2745}
2746
2747impl From<bool> for DataCollectionChoice {
2748    fn from(value: bool) -> Self {
2749        match value {
2750            true => DataCollectionChoice::Enabled,
2751            false => DataCollectionChoice::Disabled,
2752        }
2753    }
2754}
2755
2756struct ZedPredictUpsell;
2757
2758impl Dismissable for ZedPredictUpsell {
2759    const KEY: &'static str = "dismissed-edit-predict-upsell";
2760
2761    fn dismissed() -> bool {
2762        // To make this backwards compatible with older versions of Zed, we
2763        // check if the user has seen the previous Edit Prediction Onboarding
2764        // before, by checking the data collection choice which was written to
2765        // the database once the user clicked on "Accept and Enable"
2766        if KEY_VALUE_STORE
2767            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2768            .log_err()
2769            .is_some_and(|s| s.is_some())
2770        {
2771            return true;
2772        }
2773
2774        KEY_VALUE_STORE
2775            .read_kvp(Self::KEY)
2776            .log_err()
2777            .is_some_and(|s| s.is_some())
2778    }
2779}
2780
/// Whether the edit-prediction upsell modal should be presented, i.e. the
/// user has neither dismissed it nor completed the older onboarding flow.
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2784
/// Registers zeta's workspace actions and sets up command-palette gating.
///
/// - `RateCompletions`: opens the rate-predictions modal, gated on
///   `PredictEditsRateCompletionsFeatureFlag`.
/// - `OpenZedPredictOnboarding`: opens the onboarding modal.
/// - `ResetOnboarding`: sets the edit-prediction provider to `None` in the
///   settings file.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    // Register per-workspace actions as workspaces are created.
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2820
/// Hides zeta-related actions from the command palette by default and keeps
/// their visibility in sync with the AI-disabled setting and the
/// rate-completions feature flag.
fn feature_gate_predict_edits_actions(cx: &mut App) {
    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
    // Every zeta action, used to hide the whole set when AI is disabled.
    let zeta_all_action_types = [
        TypeId::of::<RateCompletions>(),
        TypeId::of::<ResetOnboarding>(),
        zed_actions::OpenZedPredictOnboarding.type_id(),
        TypeId::of::<ClearHistory>(),
        TypeId::of::<ThumbsUpActivePrediction>(),
        TypeId::of::<ThumbsDownActivePrediction>(),
        TypeId::of::<NextEdit>(),
        TypeId::of::<PreviousEdit>(),
    ];

    // Hide everything gated until settings/flags say otherwise.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_action_types(&rate_completion_action_types);
        filter.hide_action_types(&reset_onboarding_action_types);
        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
    });

    cx.observe_global::<SettingsStore>(move |cx| {
        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();

        // NOTE(review): when AI is re-enabled, only the rate-completion
        // actions are re-shown here; the rest of `zeta_all_action_types`
        // remain hidden — confirm this is intentional.
        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            if is_ai_disabled {
                filter.hide_action_types(&zeta_all_action_types);
            } else if has_feature_flag {
                filter.show_action_types(&rate_completion_action_types);
            } else {
                filter.hide_action_types(&rate_completion_action_types);
            }
        });
    })
    .detach();

    // React to the feature flag arriving/changing after startup.
    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
        if !DisableAiSettings::get_global(cx).disable_ai {
            if is_enabled {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.show_action_types(&rate_completion_action_types);
                });
            } else {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.hide_action_types(&rate_completion_action_types);
                });
            }
        }
    })
    .detach();
}
2872
2873#[cfg(test)]
2874mod tests {
2875    use std::{path::Path, sync::Arc};
2876
2877    use client::UserStore;
2878    use clock::FakeSystemClock;
2879    use cloud_llm_client::{
2880        EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
2881    };
2882    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2883    use futures::{
2884        AsyncReadExt, StreamExt,
2885        channel::{mpsc, oneshot},
2886    };
2887    use gpui::{
2888        Entity, TestAppContext,
2889        http_client::{FakeHttpClient, Response},
2890        prelude::*,
2891    };
2892    use indoc::indoc;
2893    use language::OffsetRangeExt as _;
2894    use open_ai::Usage;
2895    use pretty_assertions::{assert_eq, assert_matches};
2896    use project::{FakeFs, Project};
2897    use serde_json::json;
2898    use settings::SettingsStore;
2899    use util::path;
2900    use uuid::Uuid;
2901
2902    use crate::{BufferEditPrediction, Zeta};
2903
    // Exercises Zeta's "current prediction" state machine end to end: a local
    // prediction in the edited buffer, a context refresh driven by a model
    // tool call, explicit rejection, and a prediction targeting another file
    // (surfaced as a Jump, then as Local once that buffer is open).
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();

        // The model answers with a unified diff touching the edited file.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        // The model responds with a search tool call instead of text content.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // From buffer1's perspective the prediction is a Jump to 2.txt.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once 2.txt itself is open, the same prediction is Local for it.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
3050
    // Happy path: a single prediction request whose diff response is parsed
    // into one buffer edit at the expected position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        // The diff's single hunk becomes one edit inserting " are you?".
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3114
    // Verifies that edit history events for a registered buffer are included
    // in the prediction prompt as a unified diff of the user's recent edit.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Register the buffer so zeta records subsequent edits as events.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = requests.predict.next().await.unwrap();

        // The edit above should appear in the prompt as a diff hunk.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3188
    // A model response whose diff is a no-op should yield no current
    // prediction and be auto-reported as rejected with reason `Empty`.
    #[gpui::test]
    async fn test_empty_prediction(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // Replaces "How" with "How" — a diff that changes nothing.
        const NO_OP_DIFF: &str = indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How
             Bye
        "};

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let response = model_response(NO_OP_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::Empty,
                was_shown: false
            }]
        );
    }
3252
    /// If the buffer changes while a request is in flight such that the
    /// returned edit is already present in the buffer, interpolating the
    /// prediction yields nothing; it must be dropped and reported as rejected
    /// with reason `InterpolatedEmpty`.
    #[gpui::test]
    async fn test_interpolated_empty(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // While the request is pending, type the exact text the model is
        // about to predict.
        buffer.update(cx, |buffer, cx| {
            buffer.set_text("Hello!\nHow are you?\nBye", cx);
        });

        let response = model_response(SIMPLE_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        // The prediction interpolates to nothing against the new buffer
        // contents, so none is offered.
        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::InterpolatedEmpty,
                was_shown: false
            }]
        );
    }
3311
    // Shared fixture: a model response diff that rewrites the `How` line of
    // root/foo.md to `How are you?` — the canonical "useful prediction" used
    // by several tests below.
    const SIMPLE_DIFF: &str = indoc! { r"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ ... @@
         Hello!
        -How
        +How are you?
         Bye
    "};
3321
    /// When a second prediction arrives that is at least as good as the
    /// current one, it replaces it, and the first is reported as rejected
    /// with reason `Replaced`.
    #[gpui::test]
    async fn test_replace_current(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_tx.send(first_response).unwrap();

        cx.run_until_parked();

        // The first prediction is now current.
        zeta.read_with(cx, |zeta, cx| {
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // a second request is triggered
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_tx.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // second replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as replaced
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }]
        );
    }
3401
    /// When a later prediction is worse than the one currently shown, the
    /// current prediction is kept and the newcomer is reported as rejected
    /// with reason `CurrentPreferred`.
    #[gpui::test]
    async fn test_current_preferred(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_tx.send(first_response).unwrap();

        cx.run_until_parked();

        // The first prediction is now current.
        zeta.read_with(cx, |zeta, cx| {
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // a second request is triggered
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        // worse than current prediction: predicts only a prefix ("How are")
        // of what the current prediction offers ("How are you?").
        let second_response = model_response(indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "});
        let second_id = second_response.id.clone();
        respond_tx.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // first is preferred over second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // second is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: second_id,
                reason: EditPredictionRejectReason::CurrentPreferred,
                was_shown: false
            }]
        );
    }
3490
    /// If a newer in-flight request completes before an older one, the older
    /// request is cancelled: its late response must not displace the newer
    /// prediction, and it is reported as rejected with reason `Canceled`.
    #[gpui::test]
    async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            // start two refresh tasks
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_first) = requests.predict.next().await.unwrap();
        let (_, respond_second) = requests.predict.next().await.unwrap();

        // wait for throttle
        cx.run_until_parked();

        // second responds first
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_second.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // Now let the older (first) request respond late.
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still second, since first was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            }]
        );
    }
3576
    /// With two requests already pending, starting a third cancels the
    /// middle one (index 1). The first and third can still complete in order
    /// (third replacing first); the cancelled second must be ignored when it
    /// responds, and the rejection report must contain both the cancelled
    /// second and the replaced first.
    #[gpui::test]
    async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            // start two refresh tasks
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // wait for throttle, so requests are sent
        cx.run_until_parked();

        let (_, respond_first) = requests.predict.next().await.unwrap();
        let (_, respond_second) = requests.predict.next().await.unwrap();

        zeta.update(cx, |zeta, cx| {
            // start a third request
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

            // 2 are pending, so 2nd is cancelled
            assert_eq!(
                zeta.get_or_init_zeta_project(&project, cx)
                    .cancelled_predictions
                    .iter()
                    .copied()
                    .collect::<Vec<_>>(),
                [1]
            );
        });

        // wait for throttle
        cx.run_until_parked();

        let (_, respond_third) = requests.predict.next().await.unwrap();

        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // Respond on the cancelled (second) request; it must be discarded.
        let cancelled_response = model_response(SIMPLE_DIFF);
        let cancelled_id = cancelled_response.id.clone();
        respond_second.send(cancelled_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still first, since second was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        let third_response = model_response(SIMPLE_DIFF);
        let third_response_id = third_response.id.clone();
        respond_third.send(third_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // third completes and replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                third_response_id
            );
        });

        // second is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        // Both outcomes are batched into a single reject request:
        // second cancelled, first replaced by third.
        assert_eq!(
            &reject_request.rejections,
            &[
                EditPredictionRejection {
                    request_id: cancelled_id,
                    reason: EditPredictionRejectReason::Canceled,
                    was_shown: false
                },
                EditPredictionRejection {
                    request_id: first_id,
                    reason: EditPredictionRejectReason::Replaced,
                    was_shown: false
                }
            ]
        );
    }
3704
3705    // Skipped until we start including diagnostics in prompt
3706    // #[gpui::test]
3707    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3708    //     let (zeta, mut req_rx) = init_test(cx);
3709    //     let fs = FakeFs::new(cx.executor());
3710    //     fs.insert_tree(
3711    //         "/root",
3712    //         json!({
3713    //             "foo.md": "Hello!\nBye"
3714    //         }),
3715    //     )
3716    //     .await;
3717    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3718
3719    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3720    //     let diagnostic = lsp::Diagnostic {
3721    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3722    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
3723    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3724    //         ..Default::default()
3725    //     };
3726
3727    //     project.update(cx, |project, cx| {
3728    //         project.lsp_store().update(cx, |lsp_store, cx| {
3729    //             // Create some diagnostics
3730    //             lsp_store
3731    //                 .update_diagnostics(
3732    //                     LanguageServerId(0),
3733    //                     lsp::PublishDiagnosticsParams {
3734    //                         uri: path_to_buffer_uri.clone(),
3735    //                         diagnostics: vec![diagnostic],
3736    //                         version: None,
3737    //                     },
3738    //                     None,
3739    //                     language::DiagnosticSourceKind::Pushed,
3740    //                     &[],
3741    //                     cx,
3742    //                 )
3743    //                 .unwrap();
3744    //         });
3745    //     });
3746
3747    //     let buffer = project
3748    //         .update(cx, |project, cx| {
3749    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3750    //             project.open_buffer(path, cx)
3751    //         })
3752    //         .await
3753    //         .unwrap();
3754
3755    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3756    //     let position = snapshot.anchor_before(language::Point::new(0, 0));
3757
3758    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
3759    //         zeta.request_prediction(&project, &buffer, position, cx)
3760    //     });
3761
3762    //     let (request, _respond_tx) = req_rx.next().await.unwrap();
3763
3764    //     assert_eq!(request.diagnostic_groups.len(), 1);
3765    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3766    //         .unwrap();
3767    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3768    //     assert_eq!(
3769    //         value,
3770    //         json!({
3771    //             "entries": [{
3772    //                 "range": {
3773    //                     "start": 8,
3774    //                     "end": 10
3775    //                 },
3776    //                 "diagnostic": {
3777    //                     "source": null,
3778    //                     "code": null,
3779    //                     "code_description": null,
3780    //                     "severity": 1,
3781    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3782    //                     "markdown": null,
3783    //                     "group_id": 0,
3784    //                     "is_primary": true,
3785    //                     "is_disk_based": false,
3786    //                     "is_unnecessary": false,
3787    //                     "source_kind": "Pushed",
3788    //                     "data": null,
3789    //                     "underline": true
3790    //                 }
3791    //             }],
3792    //             "primary_ix": 0
3793    //         })
3794    //     );
3795    // }
3796
3797    fn model_response(text: &str) -> open_ai::Response {
3798        open_ai::Response {
3799            id: Uuid::new_v4().to_string(),
3800            object: "response".into(),
3801            created: 0,
3802            model: "model".into(),
3803            choices: vec![open_ai::Choice {
3804                index: 0,
3805                message: open_ai::RequestMessage::Assistant {
3806                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
3807                    tool_calls: vec![],
3808                },
3809                finish_reason: None,
3810            }],
3811            usage: Usage {
3812                prompt_tokens: 0,
3813                completion_tokens: 0,
3814                total_tokens: 0,
3815            },
3816        }
3817    }
3818
3819    fn prompt_from_request(request: &open_ai::Request) -> &str {
3820        assert_eq!(request.messages.len(), 1);
3821        let open_ai::RequestMessage::User {
3822            content: open_ai::MessageContent::Plain(content),
3823            ..
3824        } = &request.messages[0]
3825        else {
3826            panic!(
3827                "Request does not have single user message of type Plain. {:#?}",
3828                request
3829            );
3830        };
3831        content
3832    }
3833
    /// Receiving ends of the channels through which the fake HTTP server set
    /// up in `init_test` forwards intercepted requests. Each item pairs the
    /// decoded request body with a oneshot sender the test uses to supply
    /// the response.
    struct RequestChannels {
        // Intercepted `/predict_edits/raw` prediction requests.
        predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
        // Intercepted `/predict_edits/reject` rejection reports.
        reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
    }
3838
    /// Shared test harness: installs test settings, a fake HTTP client that
    /// routes the LLM-token, predict, and reject endpoints onto channels the
    /// test can drive, and a `Zeta` wired to a client with fake credentials.
    /// Returns the `Zeta` entity plus the request-channel receivers.
    fn init_test(cx: &mut TestAppContext) -> (Entity<Zeta>, RequestChannels) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
            let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let predict_req_tx = predict_req_tx.clone();
                    let reject_req_tx = reject_req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh: always hand back a dummy token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction requests: forward the decoded body to
                            // the test and block until it sends a response.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                                // NOTE: `?` means a dropped responder surfaces
                                // as an HTTP error rather than a hang.
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            // Rejection reports: same forward-and-wait scheme.
                            "/predict_edits/reject" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            // Fake signed-in user so authenticated endpoints are reachable.
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (
                zeta,
                RequestChannels {
                    predict: predict_req_rx,
                    reject: reject_req_rx,
                },
            )
        })
    }
3905}