zeta.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
   7    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
   8    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBody,
   9    ZED_VERSION_HEADER_NAME,
  10};
  11use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  12use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  13use collections::{HashMap, HashSet};
  14use command_palette_hooks::CommandPaletteFilter;
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{
  17    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  18    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  19    SyntaxIndex, SyntaxIndexState,
  20};
  21use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
  22use futures::channel::{mpsc, oneshot};
  23use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _};
  24use gpui::{
  25    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  26    http_client::{self, AsyncBody, Method},
  27    prelude::*,
  28};
  29use language::{
  30    Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
  31};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, RefreshLlmTokenListener};
  34use open_ai::FunctionDefinition;
  35use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  36use release_channel::AppVersion;
  37use semver::Version;
  38use serde::de::DeserializeOwned;
  39use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
  40use std::any::{Any as _, TypeId};
  41use std::collections::{VecDeque, hash_map};
  42use telemetry_events::EditPredictionRating;
  43use workspace::Workspace;
  44
  45use std::ops::Range;
  46use std::path::Path;
  47use std::rc::Rc;
  48use std::str::FromStr as _;
  49use std::sync::{Arc, LazyLock};
  50use std::time::{Duration, Instant};
  51use std::{env, mem};
  52use thiserror::Error;
  53use util::rel_path::RelPathBuf;
  54use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
  55use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  56
  57pub mod assemble_excerpts;
  58mod license_detection;
  59mod onboarding_modal;
  60mod prediction;
  61mod provider;
  62mod rate_prediction_modal;
  63pub mod retrieval_search;
  64pub mod sweep_ai;
  65pub mod udiff;
  66mod xml_edits;
  67pub mod zeta1;
  68
  69#[cfg(test)]
  70mod zeta_tests;
  71
  72use crate::assemble_excerpts::assemble_excerpts;
  73use crate::license_detection::LicenseDetectionWatcher;
  74use crate::onboarding_modal::ZedPredictModal;
  75pub use crate::prediction::EditPrediction;
  76pub use crate::prediction::EditPredictionId;
  77pub use crate::prediction::EditPredictionInputs;
  78use crate::prediction::EditPredictionResult;
  79use crate::rate_prediction_modal::{
  80    NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
  81    ThumbsUpActivePrediction,
  82};
  83pub use crate::sweep_ai::SweepAi;
  84use crate::zeta1::request_prediction_with_zeta1;
  85pub use provider::ZetaEditPredictionProvider;
  86
// Action definitions in the `edit_prediction` namespace; each doc comment
// becomes the action's description in the command palette.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  98
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Consecutive edits within this many rows of each other are coalesced into a
/// single buffer-change event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key under which the user's data-collection choice is persisted.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 103
 104pub struct SweepFeatureFlag;
 105
 106impl FeatureFlag for SweepFeatureFlag {
 107    const NAME: &str = "sweep-ai";
 108}
/// Default excerpt sizing around the cursor: between 128 and 512 bytes, with
/// roughly half of the byte budget spent before the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// By default, context is gathered agentically (model-driven search) rather
/// than from the syntax index.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Options used when context comes from the syntax index instead of agentic
/// retrieval. `max_retrieved_declarations: 0` retrieves no declarations while
/// still using imports.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Baseline [`ZetaOptions`] installed when a [`Zeta`] instance is constructed.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
 140
/// Set `ZED_ZETA2_OLLAMA` to a non-empty value to route requests to a local
/// Ollama server instead of the hosted models.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
/// Model id for the context-retrieval step; overridable via
/// `ZED_ZETA2_CONTEXT_MODEL`.
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
/// Model id for edit prediction; overridable via `ZED_ZETA2_MODEL`.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
/// Prediction endpoint override: `ZED_PREDICT_EDITS_URL` wins; otherwise the
/// local Ollama endpoint when `USE_OLLAMA` is set; otherwise `None`.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
 168
/// Feature flag gating the Zeta2 prediction pipeline; always enabled for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}

/// Global handle to the singleton [`Zeta`] entity (see [`Zeta::global`]).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 183
/// Singleton entity coordinating edit predictions across all projects.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token for the LLM backend; refreshed via `_llm_token_subscription`.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the backend requires a newer Zed
    // version (cf. MINIMUM_REQUIRED_VERSION_HEADER_NAME) — confirm.
    update_required: bool,
    /// When present, debug events are streamed to this channel (`debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    pub sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    /// Rejections buffered until flushed in a batch.
    rejected_predictions: Vec<EditPredictionRejection>,
    /// Signals the background task spawned in `new` to flush rejections.
    reject_predictions_tx: mpsc::UnboundedSender<()>,
    reject_predictions_debounce_task: Option<Task<()>>,
    /// Predictions that were shown to the user, kept for the rating modal.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
 204
/// Which backing model serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}
 212
/// Tunable parameters for prediction requests and context gathering.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}

/// How context around the cursor is gathered for a prediction.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Model-driven search for relevant context.
    Agentic(AgenticContextOptions),
    /// Context retrieved from the syntax index.
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
 233
 234impl ContextMode {
 235    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 236        match self {
 237            ContextMode::Agentic(options) => &options.excerpt,
 238            ContextMode::Syntax(options) => &options.excerpt,
 239        }
 240    }
 241}
 242
/// Events streamed to the debug channel installed by [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}
 251
/// Payload for [`ZetaDebugInfo::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// The prompt sent to the context-retrieval model.
    pub search_prompt: String,
}

/// Payload for the search-executed / retrieval-finished debug events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Payload for [`ZetaDebugInfo::EditPredictionRequested`].
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    /// Time spent gathering context before the prediction request.
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    pub local_prompt: Result<String, String>,
    /// Delivers the model response (or error string) and the request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}

/// Payload for [`ZetaDebugInfo::SearchQueriesGenerated`].
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 283
/// Per-project prediction state tracked by [`Zeta`].
struct ZetaProject {
    /// Only present in [`ContextMode::Syntax`].
    syntax_index: Option<Entity<SyntaxIndex>>,
    /// Finalized buffer-change events, capped via EVENT_COUNT_MAX.
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    /// The in-progress (not yet finalized) buffer-change event.
    last_event: Option<LastEvent>,
    /// Most-recently-active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled before completion.
    cancelled_predictions: HashSet<usize>,
    /// Retrieved context ranges per buffer, if retrieval has run.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    /// One license watcher per worktree, used to flag open-source repos.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
 302
 303impl ZetaProject {
 304    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
 305        self.events
 306            .iter()
 307            .cloned()
 308            .chain(
 309                self.last_event
 310                    .as_ref()
 311                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 312            )
 313            .collect()
 314    }
 315
 316    fn cancel_pending_prediction(
 317        &mut self,
 318        pending_prediction: PendingPrediction,
 319        cx: &mut Context<Zeta>,
 320    ) {
 321        self.cancelled_predictions.insert(pending_prediction.id);
 322
 323        cx.spawn(async move |this, cx| {
 324            let Some(prediction_id) = pending_prediction.task.await else {
 325                return;
 326            };
 327
 328            this.update(cx, |this, cx| {
 329                this.reject_prediction(
 330                    prediction_id,
 331                    EditPredictionRejectReason::Canceled,
 332                    false,
 333                    cx,
 334                );
 335            })
 336            .ok();
 337        })
 338        .detach()
 339    }
 340}
 341
/// The prediction currently offered to the user for a given project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered this prediction (a buffer edit or a diagnostics update).
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been displayed to the user yet.
    pub was_shown: bool,
}
 348
 349impl CurrentEditPrediction {
 350    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 351        let Some(new_edits) = self
 352            .prediction
 353            .interpolate(&self.prediction.buffer.read(cx))
 354        else {
 355            return false;
 356        };
 357
 358        if self.prediction.buffer != old_prediction.prediction.buffer {
 359            return true;
 360        }
 361
 362        let Some(old_edits) = old_prediction
 363            .prediction
 364            .interpolate(&old_prediction.prediction.buffer.read(cx))
 365        else {
 366            return true;
 367        };
 368
 369        let requested_by_buffer_id = self.requested_by.buffer_id();
 370
 371        // This reduces the occurrence of UI thrash from replacing edits
 372        //
 373        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 374        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 375            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 376            && old_edits.len() == 1
 377            && new_edits.len() == 1
 378        {
 379            let (old_range, old_text) = &old_edits[0];
 380            let (new_range, new_text) = &new_edits[0];
 381            new_range == old_range && new_text.starts_with(old_text.as_ref())
 382        } else {
 383            true
 384        }
 385    }
 386}
 387
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    /// An edit in the buffer with this entity id.
    Buffer(EntityId),
}
 393
 394impl PredictionRequestedBy {
 395    pub fn buffer_id(&self) -> Option<EntityId> {
 396        match self {
 397            PredictionRequestedBy::DiagnosticsUpdate => None,
 398            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 399        }
 400    }
 401}
 402
/// An in-flight prediction request; the task resolves to the prediction's id
/// once the request completes (or `None`).
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer directly.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer; offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}
 415
 416#[cfg(test)]
 417impl std::ops::Deref for BufferEditPrediction<'_> {
 418    type Target = EditPrediction;
 419
 420    fn deref(&self) -> &Self::Target {
 421        match self {
 422            BufferEditPrediction::Local { prediction } => prediction,
 423            BufferEditPrediction::Jump { prediction } => prediction,
 424        }
 425    }
 426}
 427
/// Subscription state for a buffer tracked within a [`ZetaProject`].
struct RegisteredBuffer {
    /// Last snapshot recorded for this buffer; used to diff against new edits.
    snapshot: BufferSnapshot,
    // Edit subscription + release observer (see `register_buffer_impl`).
    _subscriptions: [gpui::Subscription; 2],
}

/// A buffer change that has not yet been finalized into an event, so that
/// subsequent nearby edits can be coalesced into it.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    /// End of the most recent edit; used for proximity-based coalescing.
    end_edit_anchor: Option<Anchor>,
}
 438
 439impl LastEvent {
 440    pub fn finalize(
 441        &self,
 442        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 443        cx: &App,
 444    ) -> Option<Arc<predict_edits_v3::Event>> {
 445        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
 446        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
 447
 448        let file = self.new_snapshot.file();
 449        let old_file = self.old_snapshot.file();
 450
 451        let in_open_source_repo = [file, old_file].iter().all(|file| {
 452            file.is_some_and(|file| {
 453                license_detection_watchers
 454                    .get(&file.worktree_id(cx))
 455                    .is_some_and(|watcher| watcher.is_project_open_source())
 456            })
 457        });
 458
 459        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
 460
 461        if path == old_path && diff.is_empty() {
 462            None
 463        } else {
 464            Some(Arc::new(predict_edits_v3::Event::BufferChange {
 465                old_path,
 466                path,
 467                diff,
 468                in_open_source_repo,
 469                // TODO: Actually detect if this edit was predicted or not
 470                predicted: false,
 471            }))
 472        }
 473    }
 474}
 475
 476fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
 477    if let Some(file) = snapshot.file() {
 478        file.full_path(cx).into()
 479    } else {
 480        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 481    }
 482}
 483
 484impl Zeta {
 485    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 486        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 487    }
 488
 489    pub fn global(
 490        client: &Arc<Client>,
 491        user_store: &Entity<UserStore>,
 492        cx: &mut App,
 493    ) -> Entity<Self> {
 494        cx.try_global::<ZetaGlobal>()
 495            .map(|global| global.0.clone())
 496            .unwrap_or_else(|| {
 497                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 498                cx.set_global(ZetaGlobal(zeta.clone()));
 499                zeta
 500            })
 501    }
 502
    /// Creates the `Zeta` state. Prefer [`Zeta::global`], which memoizes the
    /// singleton, over calling this directly.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        // Background loop: flush buffered prediction rejections each time
        // `reject_predictions_tx` is signaled.
        let (reject_tx, mut reject_rx) = mpsc::unbounded();
        cx.spawn(async move |this, cx| {
            while let Some(()) = reject_rx.next().await {
                this.update(cx, |this, cx| this.flush_rejected_predictions(cx))?
                    .await
                    .log_err();
            }
            anyhow::Ok(())
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the LLM token whenever a global refresh is requested.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            rejected_predictions: Vec::new(),
            reject_predictions_debounce_task: None,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
 550
    /// Selects which backing model serves subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }
 554
 555    pub fn has_sweep_api_token(&self) -> bool {
 556        self.sweep_ai
 557            .api_token
 558            .clone()
 559            .now_or_never()
 560            .flatten()
 561            .is_some()
 562    }
 563
    /// Installs a cache used to memoize requests during evals.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 568
 569    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 570        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 571        self.debug_tx = Some(debug_watch_tx);
 572        debug_watch_rx
 573    }
 574
    /// Current prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prediction options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    /// Clears recorded buffer-change events for every registered project.
    pub fn clear_history(&mut self) {
        for zeta_project in self.projects.values_mut() {
            zeta_project.events.clear();
        }
    }
 588
 589    pub fn context_for_project(
 590        &self,
 591        project: &Entity<Project>,
 592    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 593        self.projects
 594            .get(&project.entity_id())
 595            .and_then(|project| {
 596                Some(
 597                    project
 598                        .context
 599                        .as_ref()?
 600                        .iter()
 601                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 602                )
 603            })
 604            .into_iter()
 605            .flatten()
 606    }
 607
 608    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 609        if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
 610            self.user_store.read(cx).edit_prediction_usage()
 611        } else {
 612            None
 613        }
 614    }
 615
    /// Starts tracking `project` so its buffers can produce predictions.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Registers `buffer` within `project`, subscribing to its edits and
    /// release.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 629
    /// Returns the [`ZetaProject`] for `project`, creating it (and
    /// subscribing to project events) on first use.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                // The syntax index is only needed when context is gathered
                // from syntax rather than agentically.
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
 662
    /// Reacts to project events: keeps the recently-active path list fresh and
    /// refreshes predictions when diagnostics change (behind the zeta2 flag).
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front, removing any older duplicate.
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }
 695
    /// Ensures `buffer` is tracked in `zeta_project`, wiring up an edit
    /// subscription and a release observer, and returns its registration.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily create a license-detection watcher for the buffer's worktree,
        // removing it again when the worktree is released.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record each edit as a change event for prediction
                        // context.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 755
    /// Records a buffer change as an event, coalescing consecutive nearby
    /// edits to the same buffer into the single pending [`LastEvent`].
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Nothing changed since the last recorded snapshot.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // End of the last edit; used below to decide whether the next edit is
        // close enough (in rows) to coalesce with the pending event.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // The pending event can only absorb this edit if it directly
            // follows the pending event's snapshot of the same buffer.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Evict the oldest event so the total (including the pending
        // last_event) stays within EVENT_COUNT_MAX.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        // Finalize the previous pending event before starting a new one.
        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
 818
 819    fn current_prediction_for_buffer(
 820        &self,
 821        buffer: &Entity<Buffer>,
 822        project: &Entity<Project>,
 823        cx: &App,
 824    ) -> Option<BufferEditPrediction<'_>> {
 825        let project_state = self.projects.get(&project.entity_id())?;
 826
 827        let CurrentEditPrediction {
 828            requested_by,
 829            prediction,
 830            ..
 831        } = project_state.current_prediction.as_ref()?;
 832
 833        if prediction.targets_buffer(buffer.read(cx)) {
 834            Some(BufferEditPrediction::Local { prediction })
 835        } else {
 836            let show_jump = match requested_by {
 837                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
 838                    requested_by_buffer_id == &buffer.entity_id()
 839                }
 840                PredictionRequestedBy::DiagnosticsUpdate => true,
 841            };
 842
 843            if show_jump {
 844                Some(BufferEditPrediction::Jump { prediction })
 845            } else {
 846                None
 847            }
 848        }
 849    }
 850
    /// Reports acceptance of the project's current prediction to the server
    /// and clears local prediction state for that project.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        // Only Zeta models report acceptance through the Zed LLM API.
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Accepting a prediction supersedes any requests still in flight.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // ZED_ACCEPT_PREDICTION_URL overrides the default accept endpoint
            // (presumably for local testing — TODO confirm intent).
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // Route the outcome through the shared API-response handler.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 903
 904    fn flush_rejected_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 905        match self.edit_prediction_model {
 906            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
 907            ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
 908        }
 909
 910        let client = self.client.clone();
 911        let llm_token = self.llm_token.clone();
 912        let app_version = AppVersion::global(cx);
 913        let last_rejection = self.rejected_predictions.last().cloned();
 914        let Some(last_rejection) = last_rejection else {
 915            return Task::ready(anyhow::Ok(()));
 916        };
 917
 918        let body = serde_json::to_string(&RejectEditPredictionsBody {
 919            rejections: self.rejected_predictions.clone(),
 920        })
 921        .ok();
 922
 923        cx.spawn(async move |this, cx| {
 924            let url = client
 925                .http_client()
 926                .build_zed_llm_url("/predict_edits/reject", &[])?;
 927
 928            cx.background_spawn(Self::send_api_request::<()>(
 929                move |builder| {
 930                    let req = builder.uri(url.as_ref()).body(body.clone().into());
 931                    Ok(req?)
 932                },
 933                client,
 934                llm_token,
 935                app_version,
 936            ))
 937            .await
 938            .context("Failed to reject edit predictions")?;
 939
 940            this.update(cx, |this, _| {
 941                if let Some(ix) = this
 942                    .rejected_predictions
 943                    .iter()
 944                    .position(|rejection| rejection.request_id == last_rejection.request_id)
 945                {
 946                    this.rejected_predictions.drain(..ix + 1);
 947                }
 948            })
 949        })
 950    }
 951
 952    fn reject_current_prediction(
 953        &mut self,
 954        reason: EditPredictionRejectReason,
 955        project: &Entity<Project>,
 956        cx: &mut Context<Self>,
 957    ) {
 958        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 959            project_state.pending_predictions.clear();
 960            if let Some(prediction) = project_state.current_prediction.take() {
 961                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
 962            }
 963        };
 964    }
 965
 966    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
 967        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 968            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
 969                if !current_prediction.was_shown {
 970                    current_prediction.was_shown = true;
 971                    self.shown_predictions
 972                        .push_front(current_prediction.prediction.clone());
 973                    if self.shown_predictions.len() > 50 {
 974                        let completion = self.shown_predictions.pop_back().unwrap();
 975                        self.rated_predictions.remove(&completion.id);
 976                    }
 977                }
 978            }
 979        }
 980    }
 981
 982    fn reject_prediction(
 983        &mut self,
 984        prediction_id: EditPredictionId,
 985        reason: EditPredictionRejectReason,
 986        was_shown: bool,
 987        cx: &mut Context<Self>,
 988    ) {
 989        self.rejected_predictions.push(EditPredictionRejection {
 990            request_id: prediction_id.to_string(),
 991            reason,
 992            was_shown,
 993        });
 994
 995        let reached_request_limit =
 996            self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2;
 997        let reject_tx = self.reject_predictions_tx.clone();
 998        self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
 999            const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
1000            if !reached_request_limit {
1001                cx.background_executor()
1002                    .timer(REJECT_REQUEST_DEBOUNCE)
1003                    .await;
1004            }
1005            reject_tx.unbounded_send(()).log_err();
1006        }));
1007    }
1008
1009    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1010        self.projects
1011            .get(&project.entity_id())
1012            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1013    }
1014
    /// Queues a prediction refresh for `position` in `buffer`, throttled per
    /// buffer entity via `queue_prediction_refresh`.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            // `this` is weak; if the entity is gone, resolve to no prediction.
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &buffer,
                        position,
                        PredictEditsRequestTrigger::Other,
                        cx,
                    )
                })
                .log_err()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |_cx| {
                // Tag the result with the requesting buffer so jump affordances
                // can be limited to that buffer.
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }
1050
    /// Requests an edit prediction in response to a diagnostics update,
    /// targeting the nearest diagnostic location in the project. Does nothing
    /// when a prediction already exists, since buffer-triggered predictions
    /// take precedence.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        // Throttled against the project entity itself rather than any buffer.
        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Start from the project's active entry, if it can be opened.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Default (empty) search range/cursor: no diagnostics are
                // excluded as "already nearby", so the closest one overall wins.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        // A buffer-triggered prediction may have arrived while
                        // this request was in flight; if so, keep it and mark
                        // this result as rejected in favor of the current one.
                        if this
                            .get_or_init_zeta_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1131
    /// Minimum interval between prediction refreshes for the same throttle
    /// entity (see `queue_prediction_refresh`).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Zero in tests so refreshes run immediately without timer waits.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1136
    /// Enqueues a prediction refresh, throttled per `throttle_entity` and
    /// bounded to at most two pending requests per project (one in flight plus
    /// one queued replacement).
    ///
    /// `do_refresh` performs the actual request and resolves to the new
    /// prediction result together with who requested it, or `None`.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        let zeta_project = self.get_or_init_zeta_project(&project, cx);
        let pending_prediction_id = zeta_project.next_pending_prediction_id;
        zeta_project.next_pending_prediction_id += 1;
        let last_request = zeta_project.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Throttle back-to-back refreshes for the same entity.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_zeta_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let zeta_project = this.get_or_init_zeta_project(&project, cx);

                // The refresh may also have been cancelled while the request ran.
                let is_cancelled = zeta_project
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                            };

                            // Replace the existing prediction only when the new
                            // one wins; otherwise reject the newcomer.
                            if let Some(current_prediction) =
                                zeta_project.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                        cx,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                        cx,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false, cx);
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-fetch: the reject_* calls above needed `this` mutably.
                let zeta_project = this.get_or_init_zeta_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    zeta_project.current_prediction = Some(new_prediction);
                }

                let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            zeta_project.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_zeta_project(&project, cx)
                    .pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Keep at most two pending predictions: the one in flight and the most
        // recently queued one; a third request replaces the queued one.
        if zeta_project.pending_predictions.len() <= 1 {
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if zeta_project.pending_predictions.len() == 2 {
            let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            zeta_project.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1279
1280    pub fn request_prediction(
1281        &mut self,
1282        project: &Entity<Project>,
1283        active_buffer: &Entity<Buffer>,
1284        position: language::Anchor,
1285        trigger: PredictEditsRequestTrigger,
1286        cx: &mut Context<Self>,
1287    ) -> Task<Result<Option<EditPredictionResult>>> {
1288        self.request_prediction_internal(
1289            project.clone(),
1290            active_buffer.clone(),
1291            position,
1292            trigger,
1293            cx.has_flag::<Zeta2FeatureFlag>(),
1294            cx,
1295        )
1296    }
1297
    /// Core prediction request: snapshots the buffer, dispatches to the
    /// configured model, and — when no prediction comes back and `allow_jump`
    /// is true — retries once at the nearest diagnostic outside the searched
    /// range (the retry passes `allow_jump = false`, so at most one jump
    /// happens per request).
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Rows on either side of the cursor considered "nearby" when
        // collecting diagnostics for the request.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                // Only jump when there has been recent edit activity —
                // NOTE(review): presumably to avoid unprompted jumps; confirm.
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1389
    /// Picks a diagnostic location to jump to for a follow-up prediction.
    ///
    /// First looks in the active buffer for the diagnostic closest to the
    /// cursor that falls outside `active_buffer_diagnostic_search_range`
    /// (diagnostics inside that range were already covered by the previous
    /// request). Failing that, falls back to another buffer with diagnostics,
    /// preferring the path sharing the most leading components with the
    /// active buffer's path.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                // Each group is represented by its primary entry's range.
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            // Closest by row distance to the cursor.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // Fall back to diagnostics in other buffers of the project.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // min_by_key on severity selects the numerically smallest
                // value, i.e. the most severe entry under LSP ordering
                // (Error = 1) — NOTE(review): confirm severity ordering.
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1468
1469    fn request_prediction_with_zeta2(
1470        &mut self,
1471        project: &Entity<Project>,
1472        active_buffer: &Entity<Buffer>,
1473        active_snapshot: BufferSnapshot,
1474        position: language::Anchor,
1475        events: Vec<Arc<Event>>,
1476        trigger: PredictEditsRequestTrigger,
1477        cx: &mut Context<Self>,
1478    ) -> Task<Result<Option<EditPredictionResult>>> {
1479        let project_state = self.projects.get(&project.entity_id());
1480
1481        let index_state = project_state.and_then(|state| {
1482            state
1483                .syntax_index
1484                .as_ref()
1485                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
1486        });
1487        let options = self.options.clone();
1488        let buffer_snapshotted_at = Instant::now();
1489        let Some(excerpt_path) = active_snapshot
1490            .file()
1491            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
1492        else {
1493            return Task::ready(Err(anyhow!("No file path for excerpt")));
1494        };
1495        let client = self.client.clone();
1496        let llm_token = self.llm_token.clone();
1497        let app_version = AppVersion::global(cx);
1498        let worktree_snapshots = project
1499            .read(cx)
1500            .worktrees(cx)
1501            .map(|worktree| worktree.read(cx).snapshot())
1502            .collect::<Vec<_>>();
1503        let debug_tx = self.debug_tx.clone();
1504
1505        let diagnostics = active_snapshot.diagnostic_sets().clone();
1506
1507        let file = active_buffer.read(cx).file();
1508        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
1509            let mut path = f.worktree.read(cx).absolutize(&f.path);
1510            if path.pop() { Some(path) } else { None }
1511        });
1512
1513        // TODO data collection
1514        let can_collect_data = file
1515            .as_ref()
1516            .map_or(false, |file| self.can_collect_file(project, file, cx));
1517
1518        let empty_context_files = HashMap::default();
1519        let context_files = project_state
1520            .and_then(|project_state| project_state.context.as_ref())
1521            .unwrap_or(&empty_context_files);
1522
1523        #[cfg(feature = "eval-support")]
1524        let parsed_fut = futures::future::join_all(
1525            context_files
1526                .keys()
1527                .map(|buffer| buffer.read(cx).parsing_idle()),
1528        );
1529
1530        let mut included_files = context_files
1531            .iter()
1532            .filter_map(|(buffer_entity, ranges)| {
1533                let buffer = buffer_entity.read(cx);
1534                Some((
1535                    buffer_entity.clone(),
1536                    buffer.snapshot(),
1537                    buffer.file()?.full_path(cx).into(),
1538                    ranges.clone(),
1539                ))
1540            })
1541            .collect::<Vec<_>>();
1542
1543        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
1544            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
1545        });
1546
1547        #[cfg(feature = "eval-support")]
1548        let eval_cache = self.eval_cache.clone();
1549
1550        let request_task = cx.background_spawn({
1551            let active_buffer = active_buffer.clone();
1552            async move {
1553                #[cfg(feature = "eval-support")]
1554                parsed_fut.await;
1555
1556                let index_state = if let Some(index_state) = index_state {
1557                    Some(index_state.lock_owned().await)
1558                } else {
1559                    None
1560                };
1561
1562                let cursor_offset = position.to_offset(&active_snapshot);
1563                let cursor_point = cursor_offset.to_point(&active_snapshot);
1564
1565                let before_retrieval = Instant::now();
1566
1567                let (diagnostic_groups, diagnostic_groups_truncated) =
1568                    Self::gather_nearby_diagnostics(
1569                        cursor_offset,
1570                        &diagnostics,
1571                        &active_snapshot,
1572                        options.max_diagnostic_bytes,
1573                    );
1574
1575                let cloud_request = match options.context {
1576                    ContextMode::Agentic(context_options) => {
1577                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
1578                            cursor_point,
1579                            &active_snapshot,
1580                            &context_options.excerpt,
1581                            index_state.as_deref(),
1582                        ) else {
1583                            return Ok((None, None));
1584                        };
1585
1586                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
1587                            ..active_snapshot.anchor_before(excerpt.range.end);
1588
1589                        if let Some(buffer_ix) =
1590                            included_files.iter().position(|(_, snapshot, _, _)| {
1591                                snapshot.remote_id() == active_snapshot.remote_id()
1592                            })
1593                        {
1594                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
1595                            ranges.push(excerpt_anchor_range);
1596                            retrieval_search::merge_anchor_ranges(ranges, buffer);
1597                            let last_ix = included_files.len() - 1;
1598                            included_files.swap(buffer_ix, last_ix);
1599                        } else {
1600                            included_files.push((
1601                                active_buffer.clone(),
1602                                active_snapshot.clone(),
1603                                excerpt_path.clone(),
1604                                vec![excerpt_anchor_range],
1605                            ));
1606                        }
1607
1608                        let included_files = included_files
1609                            .iter()
1610                            .map(|(_, snapshot, path, ranges)| {
1611                                let ranges = ranges
1612                                    .iter()
1613                                    .map(|range| {
1614                                        let point_range = range.to_point(&snapshot);
1615                                        Line(point_range.start.row)..Line(point_range.end.row)
1616                                    })
1617                                    .collect::<Vec<_>>();
1618                                let excerpts = assemble_excerpts(&snapshot, ranges);
1619                                predict_edits_v3::IncludedFile {
1620                                    path: path.clone(),
1621                                    max_row: Line(snapshot.max_point().row),
1622                                    excerpts,
1623                                }
1624                            })
1625                            .collect::<Vec<_>>();
1626
1627                        predict_edits_v3::PredictEditsRequest {
1628                            excerpt_path,
1629                            excerpt: String::new(),
1630                            excerpt_line_range: Line(0)..Line(0),
1631                            excerpt_range: 0..0,
1632                            cursor_point: predict_edits_v3::Point {
1633                                line: predict_edits_v3::Line(cursor_point.row),
1634                                column: cursor_point.column,
1635                            },
1636                            included_files,
1637                            referenced_declarations: vec![],
1638                            events,
1639                            can_collect_data,
1640                            diagnostic_groups,
1641                            diagnostic_groups_truncated,
1642                            debug_info: debug_tx.is_some(),
1643                            prompt_max_bytes: Some(options.max_prompt_bytes),
1644                            prompt_format: options.prompt_format,
1645                            // TODO [zeta2]
1646                            signatures: vec![],
1647                            excerpt_parent: None,
1648                            git_info: None,
1649                            trigger,
1650                        }
1651                    }
1652                    ContextMode::Syntax(context_options) => {
1653                        let Some(context) = EditPredictionContext::gather_context(
1654                            cursor_point,
1655                            &active_snapshot,
1656                            parent_abs_path.as_deref(),
1657                            &context_options,
1658                            index_state.as_deref(),
1659                        ) else {
1660                            return Ok((None, None));
1661                        };
1662
1663                        make_syntax_context_cloud_request(
1664                            excerpt_path,
1665                            context,
1666                            events,
1667                            can_collect_data,
1668                            diagnostic_groups,
1669                            diagnostic_groups_truncated,
1670                            None,
1671                            debug_tx.is_some(),
1672                            &worktree_snapshots,
1673                            index_state.as_deref(),
1674                            Some(options.max_prompt_bytes),
1675                            options.prompt_format,
1676                            trigger,
1677                        )
1678                    }
1679                };
1680
1681                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
1682
1683                let inputs = EditPredictionInputs {
1684                    included_files: cloud_request.included_files,
1685                    events: cloud_request.events,
1686                    cursor_point: cloud_request.cursor_point,
1687                    cursor_path: cloud_request.excerpt_path,
1688                };
1689
1690                let retrieval_time = Instant::now() - before_retrieval;
1691
1692                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
1693                    let (response_tx, response_rx) = oneshot::channel();
1694
1695                    debug_tx
1696                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
1697                            ZetaEditPredictionDebugInfo {
1698                                inputs: inputs.clone(),
1699                                retrieval_time,
1700                                buffer: active_buffer.downgrade(),
1701                                local_prompt: match prompt_result.as_ref() {
1702                                    Ok((prompt, _)) => Ok(prompt.clone()),
1703                                    Err(err) => Err(err.to_string()),
1704                                },
1705                                position,
1706                                response_rx,
1707                            },
1708                        ))
1709                        .ok();
1710                    Some(response_tx)
1711                } else {
1712                    None
1713                };
1714
1715                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
1716                    if let Some(debug_response_tx) = debug_response_tx {
1717                        debug_response_tx
1718                            .send((Err("Request skipped".to_string()), Duration::ZERO))
1719                            .ok();
1720                    }
1721                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
1722                }
1723
1724                let (prompt, _) = prompt_result?;
1725                let generation_params =
1726                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
1727                let request = open_ai::Request {
1728                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1729                    messages: vec![open_ai::RequestMessage::User {
1730                        content: open_ai::MessageContent::Plain(prompt),
1731                    }],
1732                    stream: false,
1733                    max_completion_tokens: None,
1734                    stop: generation_params.stop.unwrap_or_default(),
1735                    temperature: generation_params.temperature.unwrap_or(0.7),
1736                    tool_choice: None,
1737                    parallel_tool_calls: None,
1738                    tools: vec![],
1739                    prompt_cache_key: None,
1740                    reasoning_effort: None,
1741                };
1742
1743                log::trace!("Sending edit prediction request");
1744
1745                let before_request = Instant::now();
1746                let response = Self::send_raw_llm_request(
1747                    request,
1748                    client,
1749                    llm_token,
1750                    app_version,
1751                    #[cfg(feature = "eval-support")]
1752                    eval_cache,
1753                    #[cfg(feature = "eval-support")]
1754                    EvalCacheEntryKind::Prediction,
1755                )
1756                .await;
1757                let received_response_at = Instant::now();
1758                let request_time = received_response_at - before_request;
1759
1760                log::trace!("Got edit prediction response");
1761
1762                if let Some(debug_response_tx) = debug_response_tx {
1763                    debug_response_tx
1764                        .send((
1765                            response
1766                                .as_ref()
1767                                .map_err(|err| err.to_string())
1768                                .map(|response| response.0.clone()),
1769                            request_time,
1770                        ))
1771                        .ok();
1772                }
1773
1774                let (res, usage) = response?;
1775                let request_id = EditPredictionId(res.id.clone().into());
1776                let Some(mut output_text) = text_from_response(res) else {
1777                    return Ok((Some((request_id, None)), usage));
1778                };
1779
1780                if output_text.contains(CURSOR_MARKER) {
1781                    log::trace!("Stripping out {CURSOR_MARKER} from response");
1782                    output_text = output_text.replace(CURSOR_MARKER, "");
1783                }
1784
1785                let get_buffer_from_context = |path: &Path| {
1786                    included_files
1787                        .iter()
1788                        .find_map(|(_, buffer, probe_path, ranges)| {
1789                            if probe_path.as_ref() == path {
1790                                Some((buffer, ranges.as_slice()))
1791                            } else {
1792                                None
1793                            }
1794                        })
1795                };
1796
1797                let (edited_buffer_snapshot, edits) = match options.prompt_format {
1798                    PromptFormat::NumLinesUniDiff => {
1799                        // TODO: Implement parsing of multi-file diffs
1800                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1801                    }
1802                    PromptFormat::Minimal
1803                    | PromptFormat::MinimalQwen
1804                    | PromptFormat::SeedCoder1120 => {
1805                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
1806                            let edits = vec![];
1807                            (&active_snapshot, edits)
1808                        } else {
1809                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1810                        }
1811                    }
1812                    PromptFormat::OldTextNewText => {
1813                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
1814                            .await?
1815                    }
1816                    _ => {
1817                        bail!("unsupported prompt format {}", options.prompt_format)
1818                    }
1819                };
1820
1821                let edited_buffer = included_files
1822                    .iter()
1823                    .find_map(|(buffer, snapshot, _, _)| {
1824                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1825                            Some(buffer.clone())
1826                        } else {
1827                            None
1828                        }
1829                    })
1830                    .context("Failed to find buffer in included_buffers")?;
1831
1832                anyhow::Ok((
1833                    Some((
1834                        request_id,
1835                        Some((
1836                            inputs,
1837                            edited_buffer,
1838                            edited_buffer_snapshot.clone(),
1839                            edits,
1840                            received_response_at,
1841                        )),
1842                    )),
1843                    usage,
1844                ))
1845            }
1846        });
1847
1848        cx.spawn({
1849            async move |this, cx| {
1850                let Some((id, prediction)) =
1851                    Self::handle_api_response(&this, request_task.await, cx)?
1852                else {
1853                    return Ok(None);
1854                };
1855
1856                let Some((
1857                    inputs,
1858                    edited_buffer,
1859                    edited_buffer_snapshot,
1860                    edits,
1861                    received_response_at,
1862                )) = prediction
1863                else {
1864                    return Ok(Some(EditPredictionResult {
1865                        id,
1866                        prediction: Err(EditPredictionRejectReason::Empty),
1867                    }));
1868                };
1869
1870                // TODO telemetry: duration, etc
1871                Ok(Some(
1872                    EditPredictionResult::new(
1873                        id,
1874                        &edited_buffer,
1875                        &edited_buffer_snapshot,
1876                        edits.into(),
1877                        buffer_snapshotted_at,
1878                        received_response_at,
1879                        inputs,
1880                        cx,
1881                    )
1882                    .await,
1883                ))
1884            }
1885        })
1886    }
1887
    /// Sends a raw OpenAI-style completion request to the edit-prediction LLM
    /// endpoint, returning the parsed response plus any usage info the server
    /// reported via headers.
    ///
    /// The endpoint is taken from `PREDICT_EDITS_URL` when set (override for
    /// testing), otherwise the Zed LLM service URL is used. With the
    /// `eval-support` feature enabled, responses are cached keyed by a hash of
    /// the URL and the serialized request, so repeated identical requests
    /// during evals skip the network.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        // Prefer the explicit override URL when one is configured.
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            // Cache key covers both the endpoint and the full request body.
            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            // Cache hit: return immediately; no usage info is available.
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Record the fresh response so identical future requests hit the cache.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1945
1946    fn handle_api_response<T>(
1947        this: &WeakEntity<Self>,
1948        response: Result<(T, Option<EditPredictionUsage>)>,
1949        cx: &mut gpui::AsyncApp,
1950    ) -> Result<T> {
1951        match response {
1952            Ok((data, usage)) => {
1953                if let Some(usage) = usage {
1954                    this.update(cx, |this, cx| {
1955                        this.user_store.update(cx, |user_store, cx| {
1956                            user_store.update_edit_prediction_usage(usage, cx);
1957                        });
1958                    })
1959                    .ok();
1960                }
1961                Ok(data)
1962            }
1963            Err(err) => {
1964                if err.is::<ZedUpdateRequiredError>() {
1965                    cx.update(|cx| {
1966                        this.update(cx, |this, _cx| {
1967                            this.update_required = true;
1968                        })
1969                        .ok();
1970
1971                        let error_message: SharedString = err.to_string().into();
1972                        show_app_notification(
1973                            NotificationId::unique::<ZedUpdateRequiredError>(),
1974                            cx,
1975                            move |cx| {
1976                                cx.new(|cx| {
1977                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1978                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1979                                })
1980                            },
1981                        );
1982                    })
1983                    .ok();
1984                }
1985                Err(err)
1986            }
1987        }
1988    }
1989
    /// Sends an authenticated POST request to the LLM service and deserializes
    /// the JSON response body into `Res`.
    ///
    /// Attaches the bearer LLM token and Zed version headers, enforces the
    /// server's minimum-required-version header (failing with
    /// [`ZedUpdateRequiredError`] when this build is too old), and refreshes an
    /// expired LLM token once before giving up.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        // Only one retry is attempted, and only for an expired-token response.
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // The version requirement is checked before the status code, so an
            // outdated client errors with the update message regardless of how
            // the request itself fared.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                // Usage accounting headers are optional; absence is not an error.
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Token expired mid-session: refresh and loop to resend once.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2053
    /// Gap since the last context refresh request after which the user counts
    /// as idle; the first edit after an idle period refreshes context immediately.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce delay applied while the user is actively editing before the
    /// context refresh runs.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
2056
    /// Refreshes the related excerpts when the user has just begun editing
    /// after an idle period, and (debounced) after they pause editing.
    ///
    /// Only applies in `ContextMode::Agentic`; other context modes do nothing
    /// here.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // "Idle" means no refresh was requested within the idle window (or
        // none was ever requested for this project).
        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Replace any previously scheduled debounce task with a fresh one.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    // Actively editing: wait for a pause before refreshing.
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
2104
    /// Refreshes the related excerpts asynchronously. Ensures the task runs to
    /// completion, and avoids spawning more than one concurrent task.
    ///
    /// Flow: build a retrieval prompt from the excerpt around the cursor plus
    /// recent project events, ask the model to plan searches via the search
    /// tool, run the resulting queries against the project, and store the
    /// excerpts found as the project's context. Debug events are emitted at
    /// each stage when a debug channel is attached.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        // Context refresh is only meaningful in the agentic retrieval mode.
        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        // Untitled buffers get a placeholder path for the prompt.
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project.events(cx),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Search-tool JSON schema and its description, computed once per process.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect queries from every recognized tool call; calls to unknown
            // tools are logged and skipped rather than failing the whole refresh.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                // Clear the in-flight task marker before storing the result.
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
2316
2317    pub fn set_context(
2318        &mut self,
2319        project: Entity<Project>,
2320        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2321    ) {
2322        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2323            zeta_project.context = Some(context);
2324        }
2325    }
2326
2327    fn gather_nearby_diagnostics(
2328        cursor_offset: usize,
2329        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
2330        snapshot: &BufferSnapshot,
2331        max_diagnostics_bytes: usize,
2332    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
2333        // TODO: Could make this more efficient
2334        let mut diagnostic_groups = Vec::new();
2335        for (language_server_id, diagnostics) in diagnostic_sets {
2336            let mut groups = Vec::new();
2337            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
2338            diagnostic_groups.extend(
2339                groups
2340                    .into_iter()
2341                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
2342            );
2343        }
2344
2345        // sort by proximity to cursor
2346        diagnostic_groups.sort_by_key(|group| {
2347            let range = &group.entries[group.primary_ix].range;
2348            if range.start >= cursor_offset {
2349                range.start - cursor_offset
2350            } else if cursor_offset >= range.end {
2351                cursor_offset - range.end
2352            } else {
2353                (cursor_offset - range.start).min(range.end - cursor_offset)
2354            }
2355        });
2356
2357        let mut results = Vec::new();
2358        let mut diagnostic_groups_truncated = false;
2359        let mut diagnostics_byte_count = 0;
2360        for group in diagnostic_groups {
2361            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
2362            diagnostics_byte_count += raw_value.get().len();
2363            if diagnostics_byte_count > max_diagnostics_bytes {
2364                diagnostic_groups_truncated = true;
2365                break;
2366            }
2367            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
2368        }
2369
2370        (results, diagnostic_groups_truncated)
2371    }
2372
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds a cloud `PredictEditsRequest` for the zeta CLI, using the
    /// syntax-based context gatherer on a background task.
    ///
    /// Returns an error task when the buffer has no file path. Panics if the
    /// configured context mode is `Agentic` (not supported by the CLI yet).
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        // Clone the syntax index state handle so the background task can
        // lock it without touching entities.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, if any.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                    PredictEditsRequestTrigger::Other,
                )
            })
        })
    }
2451
2452    pub fn wait_for_initial_indexing(
2453        &mut self,
2454        project: &Entity<Project>,
2455        cx: &mut Context<Self>,
2456    ) -> Task<Result<()>> {
2457        let zeta_project = self.get_or_init_zeta_project(project, cx);
2458        if let Some(syntax_index) = &zeta_project.syntax_index {
2459            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2460        } else {
2461            Task::ready(Ok(()))
2462        }
2463    }
2464
2465    fn is_file_open_source(
2466        &self,
2467        project: &Entity<Project>,
2468        file: &Arc<dyn File>,
2469        cx: &App,
2470    ) -> bool {
2471        if !file.is_local() || file.is_private() {
2472            return false;
2473        }
2474        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2475            return false;
2476        };
2477        zeta_project
2478            .license_detection_watchers
2479            .get(&file.worktree_id(cx))
2480            .as_ref()
2481            .is_some_and(|watcher| watcher.is_project_open_source())
2482    }
2483
    /// Whether data from `file` may be collected: the user must have opted in
    /// and the file must be in an open-source project.
    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
    }
2487
2488    fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2489        if !self.data_collection_choice.is_enabled() {
2490            return false;
2491        }
2492        events.iter().all(|event| {
2493            matches!(
2494                event.as_ref(),
2495                Event::BufferChange {
2496                    in_open_source_repo: true,
2497                    ..
2498                }
2499            )
2500        })
2501    }
2502
2503    fn load_data_collection_choice() -> DataCollectionChoice {
2504        let choice = KEY_VALUE_STORE
2505            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2506            .log_err()
2507            .flatten();
2508
2509        match choice.as_deref() {
2510            Some("true") => DataCollectionChoice::Enabled,
2511            Some("false") => DataCollectionChoice::Disabled,
2512            Some(_) => {
2513                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2514                DataCollectionChoice::NotAnswered
2515            }
2516            None => DataCollectionChoice::NotAnswered,
2517        }
2518    }
2519
    /// Iterates over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2523
    /// Number of predictions that have been shown to the user.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2527
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2531
    /// Records a user rating for `prediction`: marks it rated, emits a
    /// telemetry event carrying the prediction's inputs, its unified diff,
    /// and the free-form `feedback`, then flushes telemetry immediately.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Flush right away instead of waiting for the periodic flush.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2550}
2551
2552pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2553    let choice = res.choices.pop()?;
2554    let output_text = match choice.message {
2555        open_ai::RequestMessage::Assistant {
2556            content: Some(open_ai::MessageContent::Plain(content)),
2557            ..
2558        } => content,
2559        open_ai::RequestMessage::Assistant {
2560            content: Some(open_ai::MessageContent::Multipart(mut content)),
2561            ..
2562        } => {
2563            if content.is_empty() {
2564                log::error!("No output from Baseten completion response");
2565                return None;
2566            }
2567
2568            match content.remove(0) {
2569                open_ai::MessagePart::Text { text } => text,
2570                open_ai::MessagePart::Image { .. } => {
2571                    log::error!("Expected text, got an image");
2572                    return None;
2573                }
2574            }
2575        }
2576        _ => {
2577            log::error!("Invalid response message: {:?}", choice.message);
2578            return None;
2579        }
2580    };
2581    Some(output_text)
2582}
2583
/// Error indicating the current Zed version is too old to use edit
/// predictions.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2591
2592fn make_syntax_context_cloud_request(
2593    excerpt_path: Arc<Path>,
2594    context: EditPredictionContext,
2595    events: Vec<Arc<predict_edits_v3::Event>>,
2596    can_collect_data: bool,
2597    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2598    diagnostic_groups_truncated: bool,
2599    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2600    debug_info: bool,
2601    worktrees: &Vec<worktree::Snapshot>,
2602    index_state: Option<&SyntaxIndexState>,
2603    prompt_max_bytes: Option<usize>,
2604    prompt_format: PromptFormat,
2605    trigger: PredictEditsRequestTrigger,
2606) -> predict_edits_v3::PredictEditsRequest {
2607    let mut signatures = Vec::new();
2608    let mut declaration_to_signature_index = HashMap::default();
2609    let mut referenced_declarations = Vec::new();
2610
2611    for snippet in context.declarations {
2612        let project_entry_id = snippet.declaration.project_entry_id();
2613        let Some(path) = worktrees.iter().find_map(|worktree| {
2614            worktree.entry_for_id(project_entry_id).map(|entry| {
2615                let mut full_path = RelPathBuf::new();
2616                full_path.push(worktree.root_name());
2617                full_path.push(&entry.path);
2618                full_path
2619            })
2620        }) else {
2621            continue;
2622        };
2623
2624        let parent_index = index_state.and_then(|index_state| {
2625            snippet.declaration.parent().and_then(|parent| {
2626                add_signature(
2627                    parent,
2628                    &mut declaration_to_signature_index,
2629                    &mut signatures,
2630                    index_state,
2631                )
2632            })
2633        });
2634
2635        let (text, text_is_truncated) = snippet.declaration.item_text();
2636        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2637            path: path.as_std_path().into(),
2638            text: text.into(),
2639            range: snippet.declaration.item_line_range(),
2640            text_is_truncated,
2641            signature_range: snippet.declaration.signature_range_in_item_text(),
2642            parent_index,
2643            signature_score: snippet.score(DeclarationStyle::Signature),
2644            declaration_score: snippet.score(DeclarationStyle::Declaration),
2645            score_components: snippet.components,
2646        });
2647    }
2648
2649    let excerpt_parent = index_state.and_then(|index_state| {
2650        context
2651            .excerpt
2652            .parent_declarations
2653            .last()
2654            .and_then(|(parent, _)| {
2655                add_signature(
2656                    *parent,
2657                    &mut declaration_to_signature_index,
2658                    &mut signatures,
2659                    index_state,
2660                )
2661            })
2662    });
2663
2664    predict_edits_v3::PredictEditsRequest {
2665        excerpt_path,
2666        excerpt: context.excerpt_text.body,
2667        excerpt_line_range: context.excerpt.line_range,
2668        excerpt_range: context.excerpt.range,
2669        cursor_point: predict_edits_v3::Point {
2670            line: predict_edits_v3::Line(context.cursor_point.row),
2671            column: context.cursor_point.column,
2672        },
2673        referenced_declarations,
2674        included_files: vec![],
2675        signatures,
2676        excerpt_parent,
2677        events,
2678        can_collect_data,
2679        diagnostic_groups,
2680        diagnostic_groups_truncated,
2681        git_info,
2682        debug_info,
2683        prompt_max_bytes,
2684        prompt_format,
2685        trigger,
2686    }
2687}
2688
2689fn add_signature(
2690    declaration_id: DeclarationId,
2691    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2692    signatures: &mut Vec<Signature>,
2693    index: &SyntaxIndexState,
2694) -> Option<usize> {
2695    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2696        return Some(*signature_index);
2697    }
2698    let Some(parent_declaration) = index.declaration(declaration_id) else {
2699        log::error!("bug: missing parent declaration");
2700        return None;
2701    };
2702    let parent_index = parent_declaration.parent().and_then(|parent| {
2703        add_signature(parent, declaration_to_signature_index, signatures, index)
2704    });
2705    let (text, text_is_truncated) = parent_declaration.signature_text();
2706    let signature_index = signatures.len();
2707    signatures.push(Signature {
2708        text: text.into(),
2709        text_is_truncated,
2710        parent_index,
2711        range: parent_declaration.signature_line_range(),
2712    });
2713    declaration_to_signature_index.insert(declaration_id, signature_index);
2714    Some(signature_index)
2715}
2716
/// Cache key for eval runs: the pipeline stage plus an opaque 64-bit value
/// (presumably a content hash of the input — see `EvalCache::write`).
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2719
/// Which stage of the prediction pipeline an eval-cache entry belongs to;
/// paired with a 64-bit value to form an `EvalCacheKey`.
#[cfg(feature = "eval-support")]
// `Eq` and `Hash` added so the kind (and `EvalCacheKey` tuples containing it)
// can be used as hash-map/set keys; `PartialEq` without `Eq` on a C-like enum
// is un-idiomatic (clippy::derive_partial_eq_without_eq).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2727
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the lowercase stage name used in cache entry naming.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(name)
    }
}
2738
/// Read/write interface for caching intermediate eval results by key.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached value for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` under `key`, along with the associated `input` text.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2744
/// The user's choice about sharing edit-prediction data.
///
/// Persisted in the key-value store under `ZED_PREDICT_DATA_COLLECTION_CHOICE`
/// as "true"/"false"; a missing or unrecognized value maps to `NotAnswered`.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    // User has not responded to the data-collection prompt yet.
    NotAnswered,
    // User opted in.
    Enabled,
    // User opted out.
    Disabled,
}
2751
2752impl DataCollectionChoice {
2753    pub fn is_enabled(self) -> bool {
2754        match self {
2755            Self::Enabled => true,
2756            Self::NotAnswered | Self::Disabled => false,
2757        }
2758    }
2759
2760    pub fn is_answered(self) -> bool {
2761        match self {
2762            Self::Enabled | Self::Disabled => true,
2763            Self::NotAnswered => false,
2764        }
2765    }
2766
2767    #[must_use]
2768    pub fn toggle(&self) -> DataCollectionChoice {
2769        match self {
2770            Self::Enabled => Self::Disabled,
2771            Self::Disabled => Self::Enabled,
2772            Self::NotAnswered => Self::Enabled,
2773        }
2774    }
2775}
2776
2777impl From<bool> for DataCollectionChoice {
2778    fn from(value: bool) -> Self {
2779        match value {
2780            true => DataCollectionChoice::Enabled,
2781            false => DataCollectionChoice::Disabled,
2782        }
2783    }
2784}
2785
/// Marker type tracking whether the edit-prediction upsell has been dismissed.
struct ZedPredictUpsell;
2787
2788impl Dismissable for ZedPredictUpsell {
2789    const KEY: &'static str = "dismissed-edit-predict-upsell";
2790
2791    fn dismissed() -> bool {
2792        // To make this backwards compatible with older versions of Zed, we
2793        // check if the user has seen the previous Edit Prediction Onboarding
2794        // before, by checking the data collection choice which was written to
2795        // the database once the user clicked on "Accept and Enable"
2796        if KEY_VALUE_STORE
2797            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2798            .log_err()
2799            .is_some_and(|s| s.is_some())
2800        {
2801            return true;
2802        }
2803
2804        KEY_VALUE_STORE
2805            .read_kvp(Self::KEY)
2806            .log_err()
2807            .is_some_and(|s| s.is_some())
2808    }
2809}
2810
/// Whether the edit-prediction upsell modal should be shown (i.e. it has not
/// been dismissed before).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2814
/// Registers zeta's workspace actions and wires up command-palette gating.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Rating UI is only reachable behind the feature flag.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding clears the configured provider in settings.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2850
/// Hides or shows zeta's command-palette actions based on the "disable AI"
/// setting and the rate-completions feature flag.
fn feature_gate_predict_edits_actions(cx: &mut App) {
    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
    // Every zeta action, hidden wholesale when AI is disabled.
    let zeta_all_action_types = [
        TypeId::of::<RateCompletions>(),
        TypeId::of::<ResetOnboarding>(),
        zed_actions::OpenZedPredictOnboarding.type_id(),
        TypeId::of::<ClearHistory>(),
        TypeId::of::<ThumbsUpActivePrediction>(),
        TypeId::of::<ThumbsDownActivePrediction>(),
        TypeId::of::<NextEdit>(),
        TypeId::of::<PreviousEdit>(),
    ];

    // Hidden by default until settings/flag observers decide otherwise.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_action_types(&rate_completion_action_types);
        filter.hide_action_types(&reset_onboarding_action_types);
        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
    });

    // React to settings changes: disabling AI hides everything; otherwise the
    // rate-completions entry follows the feature flag.
    cx.observe_global::<SettingsStore>(move |cx| {
        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            if is_ai_disabled {
                filter.hide_action_types(&zeta_all_action_types);
            } else if has_feature_flag {
                filter.show_action_types(&rate_completion_action_types);
            } else {
                filter.hide_action_types(&rate_completion_action_types);
            }
        });
    })
    .detach();

    // Also react to feature-flag changes while AI remains enabled.
    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
        if !DisableAiSettings::get_global(cx).disable_ai {
            if is_enabled {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.show_action_types(&rate_completion_action_types);
                });
            } else {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.hide_action_types(&rate_completion_action_types);
                });
            }
        }
    })
    .detach();
}
2902
2903#[cfg(test)]
2904mod tests {
2905    use std::{path::Path, sync::Arc};
2906
2907    use client::UserStore;
2908    use clock::FakeSystemClock;
2909    use cloud_llm_client::{
2910        EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
2911    };
2912    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2913    use futures::{
2914        AsyncReadExt, StreamExt,
2915        channel::{mpsc, oneshot},
2916    };
2917    use gpui::{
2918        Entity, TestAppContext,
2919        http_client::{FakeHttpClient, Response},
2920        prelude::*,
2921    };
2922    use indoc::indoc;
2923    use language::OffsetRangeExt as _;
2924    use open_ai::Usage;
2925    use pretty_assertions::{assert_eq, assert_matches};
2926    use project::{FakeFs, Project};
2927    use serde_json::json;
2928    use settings::SettingsStore;
2929    use util::path;
2930    use uuid::Uuid;
2931
2932    use crate::{BufferEditPrediction, Zeta};
2933
    // Verifies how the current prediction is surfaced per buffer: a local
    // prediction for the buffer it targets, a `Jump` variant when the
    // prediction targets a different file, and that a context-refresh
    // round-trip (model issues a search tool call) completes.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        // Prediction edits the buffer it was requested from -> `Local`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        // Respond with a search tool call targeting 2.txt.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // From buffer1's perspective the prediction targets 2.txt -> `Jump`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // From the target buffer's own perspective it is `Local`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
3080
    // End-to-end happy path: request a prediction, respond with a unified
    // diff, and check the parsed edit's position and inserted text.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        // The diff's single changed line becomes one edit appending to "How".
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3144
    // Verifies that prior buffer edits are included as a diff (events) in the
    // prompt of a subsequent prediction request.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Register the buffer so zeta records subsequent edits as events.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (request, respond_tx) = requests.predict.next().await.unwrap();

        // The edit above must show up in the prompt as a unified diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3218
    /// A response whose diff makes no effective change should produce no
    /// current prediction, and should be reported back to the server as
    /// rejected with the `Empty` reason (without having been shown).
    #[gpui::test]
    async fn test_empty_prediction(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // A diff that replaces "How" with itself, i.e. no effective change.
        const NO_OP_DIFF: &str = indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How
             Bye
        "};

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let response = model_response(NO_OP_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        // The no-op prediction is discarded rather than surfaced.
        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::Empty,
                was_shown: false
            }]
        );
    }
3282
    /// If the buffer is edited to already contain the predicted text while
    /// the request is in flight, interpolating the prediction against the new
    /// contents yields no edits; the prediction is dropped and reported as
    /// rejected with the `InterpolatedEmpty` reason.
    #[gpui::test]
    async fn test_interpolated_empty(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // Apply the text the model is about to predict BEFORE it responds, so
        // the prediction becomes redundant once interpolated.
        buffer.update(cx, |buffer, cx| {
            buffer.set_text("Hello!\nHow are you?\nBye", cx);
        });

        let response = model_response(SIMPLE_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::InterpolatedEmpty,
                was_shown: false
            }]
        );
    }
3341
    /// Model response diff shared by several tests below: rewrites the "How"
    /// line of `foo.md` to "How are you?".
    const SIMPLE_DIFF: &str = indoc! { r"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ ... @@
         Hello!
        -How
        +How are you?
         Bye
    "};
3351
    /// A newer completed prediction replaces the current one, and the
    /// replaced prediction is reported as rejected with the `Replaced`
    /// reason.
    #[gpui::test]
    async fn test_replace_current(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_tx.send(first_response).unwrap();

        cx.run_until_parked();

        // The first response becomes the current prediction.
        zeta.read_with(cx, |zeta, cx| {
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // a second request is triggered
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_tx.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // second replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as replaced
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }]
        );
    }
3431
    /// When a newly completed prediction is worse than the current one, the
    /// current prediction is kept and the new one is reported as rejected
    /// with the `CurrentPreferred` reason.
    #[gpui::test]
    async fn test_current_preferred(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_tx.send(first_response).unwrap();

        cx.run_until_parked();

        // The first response becomes the current prediction.
        zeta.read_with(cx, |zeta, cx| {
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // a second request is triggered
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        // worse than current prediction
        let second_response = model_response(indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How are
             Bye
        "});
        let second_id = second_response.id.clone();
        respond_tx.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // first is preferred over second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // second is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: second_id,
                reason: EditPredictionRejectReason::CurrentPreferred,
                was_shown: false
            }]
        );
    }
3520
    /// When a later request completes before an earlier pending one, the
    /// earlier request is cancelled: its eventual response does not replace
    /// the current prediction, and it is reported as rejected with the
    /// `Canceled` reason.
    #[gpui::test]
    async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            // start two refresh tasks
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_first) = requests.predict.next().await.unwrap();
        let (_, respond_second) = requests.predict.next().await.unwrap();

        // wait for throttle
        cx.run_until_parked();

        // second responds first
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_second.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still second, since first was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            }]
        );
    }
3606
    /// Starting a third request while two are already pending cancels the
    /// second pending request. The cancelled response is ignored when it
    /// arrives, and the reject report contains both the cancellation and the
    /// later replacement of the first prediction by the third.
    #[gpui::test]
    async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            // start two refresh tasks
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // wait for throttle, so requests are sent
        cx.run_until_parked();

        let (_, respond_first) = requests.predict.next().await.unwrap();
        let (_, respond_second) = requests.predict.next().await.unwrap();

        zeta.update(cx, |zeta, cx| {
            // start a third request
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

            // 2 are pending, so 2nd is cancelled
            assert_eq!(
                zeta.get_or_init_zeta_project(&project, cx)
                    .cancelled_predictions
                    .iter()
                    .copied()
                    .collect::<Vec<_>>(),
                [1]
            );
        });

        // wait for throttle
        cx.run_until_parked();

        let (_, respond_third) = requests.predict.next().await.unwrap();

        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // Respond to the cancelled (second) request; its result must be ignored.
        let cancelled_response = model_response(SIMPLE_DIFF);
        let cancelled_id = cancelled_response.id.clone();
        respond_second.send(cancelled_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still first, since second was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        let third_response = model_response(SIMPLE_DIFF);
        let third_response_id = third_response.id.clone();
        respond_third.send(third_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // third completes and replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                third_response_id
            );
        });

        // second is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[
                EditPredictionRejection {
                    request_id: cancelled_id,
                    reason: EditPredictionRejectReason::Canceled,
                    was_shown: false
                },
                EditPredictionRejection {
                    request_id: first_id,
                    reason: EditPredictionRejectReason::Replaced,
                    was_shown: false
                }
            ]
        );
    }
3734
3735    // Skipped until we start including diagnostics in prompt
3736    // #[gpui::test]
3737    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3738    //     let (zeta, mut req_rx) = init_test(cx);
3739    //     let fs = FakeFs::new(cx.executor());
3740    //     fs.insert_tree(
3741    //         "/root",
3742    //         json!({
3743    //             "foo.md": "Hello!\nBye"
3744    //         }),
3745    //     )
3746    //     .await;
3747    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3748
3749    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3750    //     let diagnostic = lsp::Diagnostic {
3751    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3752    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
3753    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3754    //         ..Default::default()
3755    //     };
3756
3757    //     project.update(cx, |project, cx| {
3758    //         project.lsp_store().update(cx, |lsp_store, cx| {
3759    //             // Create some diagnostics
3760    //             lsp_store
3761    //                 .update_diagnostics(
3762    //                     LanguageServerId(0),
3763    //                     lsp::PublishDiagnosticsParams {
3764    //                         uri: path_to_buffer_uri.clone(),
3765    //                         diagnostics: vec![diagnostic],
3766    //                         version: None,
3767    //                     },
3768    //                     None,
3769    //                     language::DiagnosticSourceKind::Pushed,
3770    //                     &[],
3771    //                     cx,
3772    //                 )
3773    //                 .unwrap();
3774    //         });
3775    //     });
3776
3777    //     let buffer = project
3778    //         .update(cx, |project, cx| {
3779    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3780    //             project.open_buffer(path, cx)
3781    //         })
3782    //         .await
3783    //         .unwrap();
3784
3785    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3786    //     let position = snapshot.anchor_before(language::Point::new(0, 0));
3787
3788    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
3789    //         zeta.request_prediction(&project, &buffer, position, cx)
3790    //     });
3791
3792    //     let (request, _respond_tx) = req_rx.next().await.unwrap();
3793
3794    //     assert_eq!(request.diagnostic_groups.len(), 1);
3795    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3796    //         .unwrap();
3797    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3798    //     assert_eq!(
3799    //         value,
3800    //         json!({
3801    //             "entries": [{
3802    //                 "range": {
3803    //                     "start": 8,
3804    //                     "end": 10
3805    //                 },
3806    //                 "diagnostic": {
3807    //                     "source": null,
3808    //                     "code": null,
3809    //                     "code_description": null,
3810    //                     "severity": 1,
3811    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3812    //                     "markdown": null,
3813    //                     "group_id": 0,
3814    //                     "is_primary": true,
3815    //                     "is_disk_based": false,
3816    //                     "is_unnecessary": false,
3817    //                     "source_kind": "Pushed",
3818    //                     "data": null,
3819    //                     "underline": true
3820    //                 }
3821    //             }],
3822    //             "primary_ix": 0
3823    //         })
3824    //     );
3825    // }
3826
3827    fn model_response(text: &str) -> open_ai::Response {
3828        open_ai::Response {
3829            id: Uuid::new_v4().to_string(),
3830            object: "response".into(),
3831            created: 0,
3832            model: "model".into(),
3833            choices: vec![open_ai::Choice {
3834                index: 0,
3835                message: open_ai::RequestMessage::Assistant {
3836                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
3837                    tool_calls: vec![],
3838                },
3839                finish_reason: None,
3840            }],
3841            usage: Usage {
3842                prompt_tokens: 0,
3843                completion_tokens: 0,
3844                total_tokens: 0,
3845            },
3846        }
3847    }
3848
3849    fn prompt_from_request(request: &open_ai::Request) -> &str {
3850        assert_eq!(request.messages.len(), 1);
3851        let open_ai::RequestMessage::User {
3852            content: open_ai::MessageContent::Plain(content),
3853            ..
3854        } = &request.messages[0]
3855        else {
3856            panic!(
3857                "Request does not have single user message of type Plain. {:#?}",
3858                request
3859            );
3860        };
3861        content
3862    }
3863
    /// Receiving ends of the fake HTTP endpoints installed by `init_test`.
    /// Each item pairs a deserialized incoming request with a oneshot sender
    /// the test uses to supply that request's response.
    struct RequestChannels {
        /// Requests hitting `/predict_edits/raw`.
        predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
        /// Requests hitting `/predict_edits/reject`.
        reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
    }
3868
    /// Sets up a `Zeta` instance backed by a fake HTTP client for tests.
    ///
    /// The fake server handles three endpoints: `/client/llm_tokens` responds
    /// with a static token, while `/predict_edits/raw` and
    /// `/predict_edits/reject` forward each deserialized request — together
    /// with a oneshot sender for its response — to the returned
    /// `RequestChannels`, letting tests inspect requests and control when and
    /// how they are answered.
    fn init_test(cx: &mut TestAppContext) -> (Entity<Zeta>, RequestChannels) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
            let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let predict_req_tx = predict_req_tx.clone();
                    let reject_req_tx = reject_req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token fetch: always succeed with a fixed token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the request to the test and block until
                                // it supplies the model response.
                                let (res_tx, res_rx) = oneshot::channel();
                                predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            "/predict_edits/reject" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Same pattern for rejection reports.
                                let (res_tx, res_rx) = oneshot::channel();
                                reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (
                zeta,
                RequestChannels {
                    predict: predict_req_rx,
                    reject: reject_req_rx,
                },
            )
        })
    }
3935}