zeta.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection,
   7    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   8    RejectEditPredictionsBody, ZED_VERSION_HEADER_NAME,
   9};
  10use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  11use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  12use collections::{HashMap, HashSet};
  13use command_palette_hooks::CommandPaletteFilter;
  14use db::kvp::{Dismissable, KEY_VALUE_STORE};
  15use edit_prediction_context::{
  16    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  17    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  18    SyntaxIndex, SyntaxIndexState,
  19};
  20use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
  21use futures::channel::{mpsc, oneshot};
  22use futures::{AsyncReadExt as _, StreamExt as _};
  23use gpui::{
  24    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  25    http_client::{self, AsyncBody, Method},
  26    prelude::*,
  27};
  28use language::{
  29    Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
  30};
  31use language::{BufferSnapshot, OffsetRangeExt};
  32use language_model::{LlmApiToken, RefreshLlmTokenListener};
  33use open_ai::FunctionDefinition;
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
  39use std::any::{Any as _, TypeId};
  40use std::collections::{VecDeque, hash_map};
  41use telemetry_events::EditPredictionRating;
  42use workspace::Workspace;
  43
  44use std::ops::Range;
  45use std::path::Path;
  46use std::rc::Rc;
  47use std::str::FromStr as _;
  48use std::sync::{Arc, LazyLock};
  49use std::time::{Duration, Instant};
  50use std::{env, mem};
  51use thiserror::Error;
  52use util::rel_path::RelPathBuf;
  53use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
  54use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  55
// Submodules of the zeta edit-prediction crate.
pub mod assemble_excerpts;
mod license_detection;
mod onboarding_modal;
mod prediction;
mod provider;
mod rate_prediction_modal;
pub mod retrieval_search;
mod sweep_ai;
pub mod udiff;
mod xml_edits;
pub mod zeta1;

#[cfg(test)]
mod zeta_tests;
  70
  71use crate::assemble_excerpts::assemble_excerpts;
  72use crate::license_detection::LicenseDetectionWatcher;
  73use crate::onboarding_modal::ZedPredictModal;
  74pub use crate::prediction::EditPrediction;
  75pub use crate::prediction::EditPredictionId;
  76pub use crate::prediction::EditPredictionInputs;
  77use crate::rate_prediction_modal::{
  78    NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
  79    ThumbsUpActivePrediction,
  80};
  81use crate::sweep_ai::SweepAi;
  82use crate::zeta1::request_prediction_with_zeta1;
  83pub use provider::ZetaEditPredictionProvider;
  84
// Actions exposed under the `edit_prediction` namespace (command palette /
// keymap).
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  96
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Consecutive edits to the same buffer whose end points are within this many
/// rows are coalesced into a single change event (see
/// `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key-value-store key under which the user's data collection choice is
/// persisted.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 101
/// Feature flag gating the Sweep AI edit-prediction backend.
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}
/// Default sizing of the excerpt captured around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Context gathering defaults to agentic (search-driven) mode.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for the alternative, syntax-index-based context mode.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default [`ZetaOptions`] used when constructing [`Zeta`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
 138
 139static USE_OLLAMA: LazyLock<bool> =
 140    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 141static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 142    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
 143        "qwen3-coder:30b".to_string()
 144    } else {
 145        "yqvev8r3".to_string()
 146    })
 147});
 148static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 149    match env::var("ZED_ZETA2_MODEL").as_deref() {
 150        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 151        Ok(model) => model,
 152        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 153        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 154    }
 155    .to_string()
 156});
 157static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 158    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 159        if *USE_OLLAMA {
 160            Some("http://localhost:11434/v1/chat/completions".into())
 161        } else {
 162            None
 163        }
 164    })
 165});
 166
/// Feature flag gating the zeta2 prediction pipeline; always on for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}
 176
/// Newtype wrapper registering the singleton [`Zeta`] entity as a gpui global.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 181
/// Central coordinator for edit predictions: tracks per-project state, talks
/// to the prediction backend, and records shown/rated/rejected predictions.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Auth token for the LLM backend; refreshed via `_llm_token_subscription`.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project prediction state, keyed by project entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server demands a newer client
    // (see MINIMUM_REQUIRED_VERSION_HEADER_NAME) — confirm where it's written.
    update_required: bool,
    /// When set, debugging events are streamed to this channel (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    // Rejections are accumulated here and uploaded in batches via
    // `reject_predictions_tx`.
    rejected_predictions: Vec<EditPredictionRejection>,
    reject_predictions_tx: mpsc::UnboundedSender<()>,
    reject_predictions_debounce_task: Option<Task<()>>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
 202
/// Which backend serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    /// Third-party Sweep AI backend (accept/reject reporting does not apply).
    Sweep,
}
 210
/// Tunable knobs for prediction requests and context gathering.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}
 220
/// How context is gathered for a prediction: agentic (model-driven search) or
/// syntax-index based.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    Agentic(AgenticContextOptions),
    Syntax(EditPredictionContextOptions),
}

/// Options specific to agentic context gathering.
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
 231
 232impl ContextMode {
 233    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 234        match self {
 235            ContextMode::Agentic(options) => &options.excerpt,
 236            ContextMode::Syntax(options) => &options.excerpt,
 237        }
 238    }
 239}
 240
/// Events streamed to the debug channel created by [`Zeta::debug_info`],
/// covering the lifecycle of a prediction request.
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}
 249
/// Payload for [`ZetaDebugInfo::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// Prompt sent to the model to generate search queries.
    pub search_prompt: String,
}

/// Payload for context-retrieval progress/finish events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Payload for [`ZetaDebugInfo::EditPredictionRequested`].
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    /// How long context retrieval took before the request was issued.
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// Locally assembled prompt, or an error message if assembly failed.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error) and the request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}

/// Payload for [`ZetaDebugInfo::SearchQueriesGenerated`].
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 281
/// Per-project prediction state owned by [`Zeta`].
struct ZetaProject {
    /// Populated only in [`ContextMode::Syntax`].
    syntax_index: Option<Entity<SyntaxIndex>>,
    /// Finalized buffer-change events, oldest first (bounded by `EVENT_COUNT_MAX`).
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    /// The still-open change event; coalesces nearby edits before being
    /// finalized into `events`.
    last_event: Option<LastEvent>,
    /// Most-recently-active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// At most two prediction requests in flight at a time.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    cancelled_predictions: HashSet<usize>,
    /// Context ranges gathered per buffer for the next prediction, if any.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    /// One license watcher per worktree, used to mark events as open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
 300
 301impl ZetaProject {
 302    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
 303        self.events
 304            .iter()
 305            .cloned()
 306            .chain(
 307                self.last_event
 308                    .as_ref()
 309                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 310            )
 311            .collect()
 312    }
 313}
 314
/// The prediction currently offered to the user, plus what triggered it and
/// whether it has been displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
}
 321
 322impl CurrentEditPrediction {
 323    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 324        let Some(new_edits) = self
 325            .prediction
 326            .interpolate(&self.prediction.buffer.read(cx))
 327        else {
 328            return false;
 329        };
 330
 331        if self.prediction.buffer != old_prediction.prediction.buffer {
 332            return true;
 333        }
 334
 335        let Some(old_edits) = old_prediction
 336            .prediction
 337            .interpolate(&old_prediction.prediction.buffer.read(cx))
 338        else {
 339            return true;
 340        };
 341
 342        let requested_by_buffer_id = self.requested_by.buffer_id();
 343
 344        // This reduces the occurrence of UI thrash from replacing edits
 345        //
 346        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 347        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 348            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 349            && old_edits.len() == 1
 350            && new_edits.len() == 1
 351        {
 352            let (old_range, old_text) = &old_edits[0];
 353            let (new_range, new_text) = &new_edits[0];
 354            new_range == old_range && new_text.starts_with(old_text.as_ref())
 355        } else {
 356            true
 357        }
 358    }
 359}
 360
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    /// An edit in the buffer with this entity id.
    Buffer(EntityId),
}
 366
 367impl PredictionRequestedBy {
 368    pub fn buffer_id(&self) -> Option<EntityId> {
 369        match self {
 370            PredictionRequestedBy::DiagnosticsUpdate => None,
 371            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 372        }
 373    }
 374}
 375
/// An in-flight prediction request; the task resolves to the resulting
/// prediction id, or `None`.
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction edits a different buffer; offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}

/// Per-buffer tracking state: the last reported snapshot plus edit/release
/// subscriptions.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// The still-open change event: snapshots before/after a run of coalesced
/// edits, and where the last edit ended.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    end_edit_anchor: Option<Anchor>,
}
 398
 399impl LastEvent {
 400    pub fn finalize(
 401        &self,
 402        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 403        cx: &App,
 404    ) -> Option<Arc<predict_edits_v3::Event>> {
 405        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
 406        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
 407
 408        let file = self.new_snapshot.file();
 409        let old_file = self.old_snapshot.file();
 410
 411        let in_open_source_repo = [file, old_file].iter().all(|file| {
 412            file.is_some_and(|file| {
 413                license_detection_watchers
 414                    .get(&file.worktree_id(cx))
 415                    .is_some_and(|watcher| watcher.is_project_open_source())
 416            })
 417        });
 418
 419        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
 420
 421        if path == old_path && diff.is_empty() {
 422            None
 423        } else {
 424            Some(Arc::new(predict_edits_v3::Event::BufferChange {
 425                old_path,
 426                path,
 427                diff,
 428                in_open_source_repo,
 429                // TODO: Actually detect if this edit was predicted or not
 430                predicted: false,
 431            }))
 432        }
 433    }
 434}
 435
 436fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
 437    if let Some(file) = snapshot.file() {
 438        file.full_path(cx).into()
 439    } else {
 440        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 441    }
 442}
 443
 444impl Zeta {
 445    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 446        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 447    }
 448
 449    pub fn global(
 450        client: &Arc<Client>,
 451        user_store: &Entity<UserStore>,
 452        cx: &mut App,
 453    ) -> Entity<Self> {
 454        cx.try_global::<ZetaGlobal>()
 455            .map(|global| global.0.clone())
 456            .unwrap_or_else(|| {
 457                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 458                cx.set_global(ZetaGlobal(zeta.clone()));
 459                zeta
 460            })
 461    }
 462
    /// Constructs a `Zeta`. Spawns a background task that uploads batched
    /// prediction rejections whenever `reject_predictions_tx` is signaled, and
    /// subscribes to LLM token refresh events to keep `llm_token` current.
    /// Prefer [`Zeta::global`] over calling this directly.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        // Each `()` sent on this channel triggers an upload of everything
        // accumulated in `rejected_predictions`.
        let (reject_tx, mut reject_rx) = mpsc::unbounded();
        cx.spawn(async move |this, cx| {
            while let Some(()) = reject_rx.next().await {
                this.update(cx, |this, cx| this.reject_edit_predictions(cx))?
                    .await
                    .log_err();
            }
            anyhow::Ok(())
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the LLM token whenever the listener reports it expired.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            rejected_predictions: Vec::new(),
            reject_predictions_debounce_task: None,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
 510
    /// Selects which backend serves subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }
 514
    /// Whether a Sweep AI API token is configured.
    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_ai.api_token.is_some()
    }
 518
    /// Installs a cache used by the eval harness to avoid repeated requests.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 523
 524    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 525        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 526        self.debug_tx = Some(debug_watch_tx);
 527        debug_watch_rx
 528    }
 529
    /// The options currently in effect.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 533
    /// Replaces the options used for subsequent requests.
    // NOTE(review): already-registered projects keep their existing syntax
    // index (created only at registration time) — confirm this is intended.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 537
 538    pub fn clear_history(&mut self) {
 539        for zeta_project in self.projects.values_mut() {
 540            zeta_project.events.clear();
 541        }
 542    }
 543
 544    pub fn context_for_project(
 545        &self,
 546        project: &Entity<Project>,
 547    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 548        self.projects
 549            .get(&project.entity_id())
 550            .and_then(|project| {
 551                Some(
 552                    project
 553                        .context
 554                        .as_ref()?
 555                        .iter()
 556                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 557                )
 558            })
 559            .into_iter()
 560            .flatten()
 561    }
 562
 563    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 564        if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
 565            self.user_store.read(cx).edit_prediction_usage()
 566        } else {
 567            None
 568        }
 569    }
 570
    /// Ensures per-project state exists for `project` without registering any
    /// buffers yet.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }
 574
    /// Starts tracking `buffer` within `project`, subscribing to its edits so
    /// changes can be recorded as prediction events.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 584
 585    fn get_or_init_zeta_project(
 586        &mut self,
 587        project: &Entity<Project>,
 588        cx: &mut Context<Self>,
 589    ) -> &mut ZetaProject {
 590        self.projects
 591            .entry(project.entity_id())
 592            .or_insert_with(|| ZetaProject {
 593                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
 594                    Some(cx.new(|cx| {
 595                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 596                    }))
 597                } else {
 598                    None
 599                },
 600                events: VecDeque::new(),
 601                last_event: None,
 602                recent_paths: VecDeque::new(),
 603                registered_buffers: HashMap::default(),
 604                current_prediction: None,
 605                cancelled_predictions: HashSet::default(),
 606                pending_predictions: ArrayVec::new(),
 607                next_pending_prediction_id: 0,
 608                last_prediction_refresh: None,
 609                context: None,
 610                refresh_context_task: None,
 611                refresh_context_debounce_task: None,
 612                refresh_context_timestamp: None,
 613                license_detection_watchers: HashMap::default(),
 614                _subscription: cx.subscribe(&project, Self::handle_project_event),
 615            })
 616    }
 617
    /// Reacts to project events: maintains the most-recently-used path list
    /// and refreshes predictions on diagnostics changes (zeta2 only).
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front if it was already tracked.
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }
 650
    /// Returns the [`RegisteredBuffer`] for `buffer`, creating it on first
    /// registration. Also lazily creates a license-detection watcher for the
    /// buffer's worktree; both are torn down when the corresponding entity is
    /// released.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // One license watcher per worktree, removed when the worktree drops.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record every edit as a potential prediction event.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 710
    /// Records the latest changes to `buffer` as a pending change event,
    /// coalescing consecutive nearby edits to the same buffer into a single
    /// event (within `CHANGE_GROUPING_LINE_SPAN` rows).
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        // Nothing to report if the buffer hasn't changed since last time.
        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Where the most recent edit ended, used for proximity-based grouping.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // Coalesce only when this change directly continues the pending
            // event: same buffer, same version chain, and edits close by.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Keep the event list bounded; the pending `last_event` effectively
        // counts toward `EVENT_COUNT_MAX`.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        // The pending event can no longer be extended; finalize it.
        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
 773
 774    fn current_prediction_for_buffer(
 775        &self,
 776        buffer: &Entity<Buffer>,
 777        project: &Entity<Project>,
 778        cx: &App,
 779    ) -> Option<BufferEditPrediction<'_>> {
 780        let project_state = self.projects.get(&project.entity_id())?;
 781
 782        let CurrentEditPrediction {
 783            requested_by,
 784            prediction,
 785            ..
 786        } = project_state.current_prediction.as_ref()?;
 787
 788        if prediction.targets_buffer(buffer.read(cx)) {
 789            Some(BufferEditPrediction::Local { prediction })
 790        } else {
 791            let show_jump = match requested_by {
 792                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
 793                    requested_by_buffer_id == &buffer.entity_id()
 794                }
 795                PredictionRequestedBy::DiagnosticsUpdate => true,
 796            };
 797
 798            if show_jump {
 799                Some(BufferEditPrediction::Jump { prediction })
 800            } else {
 801                None
 802            }
 803        }
 804    }
 805
    /// Reports acceptance of the current prediction for `project` to the
    /// backend and cancels any still-pending prediction requests. No-op for
    /// the Sweep backend.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        // Acceptance reporting only applies to the Zed-hosted models.
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Accepting settles this round; anything still in flight is moot.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            self.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the endpoint (useful for
            // local development).
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 858
    /// Sends the accumulated prediction rejections to the server. On success,
    /// removes every queued rejection up to and including the newest one that
    /// existed when the request started; rejections pushed while the request
    /// was in flight stay queued for the next send.
    ///
    /// No-op for the Sweep model, which does not use this endpoint.
    fn reject_edit_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        // Remember the newest rejection so we know how much to drain on success.
        let last_rejection = self.rejected_predictions.last().cloned();
        let Some(last_rejection) = last_rejection else {
            // Nothing queued; nothing to report.
            return Task::ready(anyhow::Ok(()));
        };

        // NOTE(review): serialization failure yields `body == None`; the
        // request is still sent with an empty body — presumably the server
        // rejects it. TODO confirm this is the intended fallback.
        let body = serde_json::to_string(&RejectEditPredictionsBody {
            rejections: self.rejected_predictions.clone(),
        })
        .ok();

        cx.spawn(async move |this, cx| {
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])?;

            cx.background_spawn(Self::send_api_request::<()>(
                move |builder| {
                    let req = builder.uri(url.as_ref()).body(body.clone().into());
                    Ok(req?)
                },
                client,
                llm_token,
                app_version,
            ))
            .await
            .context("Failed to reject edit predictions")?;

            // Drop only the rejections that were part of this request.
            this.update(cx, |this, _| {
                if let Some(ix) = this
                    .rejected_predictions
                    .iter()
                    .position(|rejection| rejection.request_id == last_rejection.request_id)
                {
                    this.rejected_predictions.drain(..ix + 1);
                }
            })
        })
    }
 906
 907    fn discard_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 908        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 909            project_state.pending_predictions.clear();
 910            if let Some(prediction) = project_state.current_prediction.take() {
 911                self.discard_prediction(prediction.prediction.id, prediction.was_shown, cx);
 912            }
 913        };
 914    }
 915
 916    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
 917        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 918            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
 919                if !current_prediction.was_shown {
 920                    current_prediction.was_shown = true;
 921                    self.shown_predictions
 922                        .push_front(current_prediction.prediction.clone());
 923                    if self.shown_predictions.len() > 50 {
 924                        let completion = self.shown_predictions.pop_back().unwrap();
 925                        self.rated_predictions.remove(&completion.id);
 926                    }
 927                }
 928            }
 929        }
 930    }
 931
 932    fn discard_prediction(
 933        &mut self,
 934        prediction_id: EditPredictionId,
 935        was_shown: bool,
 936        cx: &mut Context<Self>,
 937    ) {
 938        self.rejected_predictions.push(EditPredictionRejection {
 939            request_id: prediction_id.to_string(),
 940            was_shown,
 941        });
 942
 943        let reached_request_limit =
 944            self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
 945        let reject_tx = self.reject_predictions_tx.clone();
 946        self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
 947            const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
 948            if !reached_request_limit {
 949                cx.background_executor()
 950                    .timer(DISCARD_COMPLETIONS_DEBOUNCE)
 951                    .await;
 952            }
 953            reject_tx.unbounded_send(()).log_err();
 954        }));
 955    }
 956
 957    fn cancel_pending_prediction(
 958        &self,
 959        pending_prediction: PendingPrediction,
 960        cx: &mut Context<Self>,
 961    ) {
 962        cx.spawn(async move |this, cx| {
 963            let Some(prediction_id) = pending_prediction.task.await else {
 964                return;
 965            };
 966
 967            this.update(cx, |this, cx| {
 968                this.discard_prediction(prediction_id, false, cx);
 969            })
 970            .ok();
 971        })
 972        .detach()
 973    }
 974
 975    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
 976        self.projects
 977            .get(&project.entity_id())
 978            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
 979    }
 980
 981    pub fn refresh_prediction_from_buffer(
 982        &mut self,
 983        project: Entity<Project>,
 984        buffer: Entity<Buffer>,
 985        position: language::Anchor,
 986        cx: &mut Context<Self>,
 987    ) {
 988        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
 989            let Some(request_task) = this
 990                .update(cx, |this, cx| {
 991                    this.request_prediction(&project, &buffer, position, cx)
 992                })
 993                .log_err()
 994            else {
 995                return Task::ready(anyhow::Ok(None));
 996            };
 997
 998            let project = project.clone();
 999            cx.spawn(async move |cx| {
1000                if let Some(prediction) = request_task.await? {
1001                    let id = prediction.id.clone();
1002                    this.update(cx, |this, cx| {
1003                        let project_state = this
1004                            .projects
1005                            .get_mut(&project.entity_id())
1006                            .context("Project not found")?;
1007
1008                        let new_prediction = CurrentEditPrediction {
1009                            requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
1010                            prediction: prediction,
1011                            was_shown: false,
1012                        };
1013
1014                        if project_state
1015                            .current_prediction
1016                            .as_ref()
1017                            .is_none_or(|old_prediction| {
1018                                new_prediction.should_replace_prediction(&old_prediction, cx)
1019                            })
1020                        {
1021                            project_state.current_prediction = Some(new_prediction);
1022                            cx.notify();
1023                        }
1024                        anyhow::Ok(())
1025                    })??;
1026                    Ok(Some(id))
1027                } else {
1028                    Ok(None)
1029                }
1030            })
1031        })
1032    }
1033
    /// Requests a prediction at a diagnostic location reachable from the
    /// project's active entry, unless a prediction already exists (buffer-
    /// initiated predictions take priority over diagnostics-initiated ones).
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Start from the project's active entry; bail if there is none.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Default (empty) search range and cursor point: any
                // diagnostic in this or a nearby buffer is a candidate.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(&project, &jump_buffer, jump_position, cx)
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                let id = prediction.id.clone();
                this.update(cx, |this, cx| {
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        // Only install if no prediction appeared meanwhile;
                        // a buffer-requested prediction keeps precedence.
                        zeta_project.current_prediction.get_or_insert_with(|| {
                            cx.notify();
                            CurrentEditPrediction {
                                requested_by: PredictionRequestedBy::DiagnosticsUpdate,
                                prediction,
                                was_shown: false,
                            }
                        });
                    }
                })?;

                anyhow::Ok(Some(id))
            })
        });
    }
1106
    // Minimum interval between consecutive prediction refreshes for the same
    // entity (enforced in `queue_prediction_refresh`). Zero under `cfg(test)`
    // so tests don't have to wait out the throttle.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1111
    /// Enqueues a throttled prediction refresh, keeping at most two pending
    /// requests per project: the in-flight one plus one queued replacement.
    /// `do_refresh` performs the actual request and resolves to the id of the
    /// produced prediction, if any. `throttle_entity` scopes the throttle so
    /// that only back-to-back refreshes for the same entity are delayed.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        ) -> Task<Result<Option<EditPredictionId>>>
        + 'static,
    ) {
        let zeta_project = self.get_or_init_zeta_project(&project, cx);
        let pending_prediction_id = zeta_project.next_pending_prediction_id;
        zeta_project.next_pending_prediction_id += 1;
        let last_request = zeta_project.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Wait out the remainder of the throttle window if the previous
            // refresh targeted the same entity.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_zeta_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let edit_prediction_id = do_refresh(this.clone(), cx).await.log_err().flatten();

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let zeta_project = this.get_or_init_zeta_project(&project, cx);
                zeta_project
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            this.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_zeta_project(&project, cx)
                    .pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            edit_prediction_id
        });

        // At most two entries: a new refresh either fills a free slot or
        // replaces the queued (second) one, whose occupant gets cancelled.
        if zeta_project.pending_predictions.len() <= 1 {
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if zeta_project.pending_predictions.len() == 2 {
            let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            zeta_project
                .cancelled_predictions
                .insert(pending_prediction.id);
            self.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1201
1202    pub fn request_prediction(
1203        &mut self,
1204        project: &Entity<Project>,
1205        active_buffer: &Entity<Buffer>,
1206        position: language::Anchor,
1207        cx: &mut Context<Self>,
1208    ) -> Task<Result<Option<EditPrediction>>> {
1209        self.request_prediction_internal(
1210            project.clone(),
1211            active_buffer.clone(),
1212            position,
1213            cx.has_flag::<Zeta2FeatureFlag>(),
1214            cx,
1215        )
1216    }
1217
    /// Requests a prediction at `position`, dispatching to the configured
    /// model (Zeta1 / Zeta2 / Sweep). If the model yields no edits and
    /// `allow_jump` is set, retries exactly once at the nearest diagnostic
    /// outside the range already covered by this request.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        // Diagnostics within this many rows of the cursor are considered part
        // of the current request (and excluded from jump targets).
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            // Treat a prediction with no edits as no prediction at all.
            let prediction = task
                .await?
                .filter(|prediction| !prediction.edits.is_empty());

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                // Only jump when there were user events to base a prediction
                // on; `allow_jump: false` on the retry prevents recursion.
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1307
    /// Finds a diagnostic to jump to: first the closest diagnostic in the
    /// active buffer outside `active_buffer_diagnostic_search_range`, then —
    /// failing that — a diagnostic in the project buffer whose path shares the
    /// most leading directories with the active buffer's path.
    ///
    /// Returns `None` when no suitable diagnostic exists anywhere.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // Fall back to other buffers with diagnostics.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // Pick the most severe diagnostic in the chosen buffer
                // (lower severity values are more severe).
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1386
1387    fn request_prediction_with_zeta2(
1388        &mut self,
1389        project: &Entity<Project>,
1390        active_buffer: &Entity<Buffer>,
1391        active_snapshot: BufferSnapshot,
1392        position: language::Anchor,
1393        events: Vec<Arc<Event>>,
1394        cx: &mut Context<Self>,
1395    ) -> Task<Result<Option<EditPrediction>>> {
1396        let project_state = self.projects.get(&project.entity_id());
1397
1398        let index_state = project_state.and_then(|state| {
1399            state
1400                .syntax_index
1401                .as_ref()
1402                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
1403        });
1404        let options = self.options.clone();
1405        let buffer_snapshotted_at = Instant::now();
1406        let Some(excerpt_path) = active_snapshot
1407            .file()
1408            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
1409        else {
1410            return Task::ready(Err(anyhow!("No file path for excerpt")));
1411        };
1412        let client = self.client.clone();
1413        let llm_token = self.llm_token.clone();
1414        let app_version = AppVersion::global(cx);
1415        let worktree_snapshots = project
1416            .read(cx)
1417            .worktrees(cx)
1418            .map(|worktree| worktree.read(cx).snapshot())
1419            .collect::<Vec<_>>();
1420        let debug_tx = self.debug_tx.clone();
1421
1422        let diagnostics = active_snapshot.diagnostic_sets().clone();
1423
1424        let file = active_buffer.read(cx).file();
1425        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
1426            let mut path = f.worktree.read(cx).absolutize(&f.path);
1427            if path.pop() { Some(path) } else { None }
1428        });
1429
1430        // TODO data collection
1431        let can_collect_data = file
1432            .as_ref()
1433            .map_or(false, |file| self.can_collect_file(project, file, cx));
1434
1435        let empty_context_files = HashMap::default();
1436        let context_files = project_state
1437            .and_then(|project_state| project_state.context.as_ref())
1438            .unwrap_or(&empty_context_files);
1439
1440        #[cfg(feature = "eval-support")]
1441        let parsed_fut = futures::future::join_all(
1442            context_files
1443                .keys()
1444                .map(|buffer| buffer.read(cx).parsing_idle()),
1445        );
1446
1447        let mut included_files = context_files
1448            .iter()
1449            .filter_map(|(buffer_entity, ranges)| {
1450                let buffer = buffer_entity.read(cx);
1451                Some((
1452                    buffer_entity.clone(),
1453                    buffer.snapshot(),
1454                    buffer.file()?.full_path(cx).into(),
1455                    ranges.clone(),
1456                ))
1457            })
1458            .collect::<Vec<_>>();
1459
1460        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
1461            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
1462        });
1463
1464        #[cfg(feature = "eval-support")]
1465        let eval_cache = self.eval_cache.clone();
1466
1467        let request_task = cx.background_spawn({
1468            let active_buffer = active_buffer.clone();
1469            async move {
1470                #[cfg(feature = "eval-support")]
1471                parsed_fut.await;
1472
1473                let index_state = if let Some(index_state) = index_state {
1474                    Some(index_state.lock_owned().await)
1475                } else {
1476                    None
1477                };
1478
1479                let cursor_offset = position.to_offset(&active_snapshot);
1480                let cursor_point = cursor_offset.to_point(&active_snapshot);
1481
1482                let before_retrieval = Instant::now();
1483
1484                let (diagnostic_groups, diagnostic_groups_truncated) =
1485                    Self::gather_nearby_diagnostics(
1486                        cursor_offset,
1487                        &diagnostics,
1488                        &active_snapshot,
1489                        options.max_diagnostic_bytes,
1490                    );
1491
1492                let cloud_request = match options.context {
1493                    ContextMode::Agentic(context_options) => {
1494                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
1495                            cursor_point,
1496                            &active_snapshot,
1497                            &context_options.excerpt,
1498                            index_state.as_deref(),
1499                        ) else {
1500                            return Ok((None, None));
1501                        };
1502
1503                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
1504                            ..active_snapshot.anchor_before(excerpt.range.end);
1505
1506                        if let Some(buffer_ix) =
1507                            included_files.iter().position(|(_, snapshot, _, _)| {
1508                                snapshot.remote_id() == active_snapshot.remote_id()
1509                            })
1510                        {
1511                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
1512                            ranges.push(excerpt_anchor_range);
1513                            retrieval_search::merge_anchor_ranges(ranges, buffer);
1514                            let last_ix = included_files.len() - 1;
1515                            included_files.swap(buffer_ix, last_ix);
1516                        } else {
1517                            included_files.push((
1518                                active_buffer.clone(),
1519                                active_snapshot.clone(),
1520                                excerpt_path.clone(),
1521                                vec![excerpt_anchor_range],
1522                            ));
1523                        }
1524
1525                        let included_files = included_files
1526                            .iter()
1527                            .map(|(_, snapshot, path, ranges)| {
1528                                let ranges = ranges
1529                                    .iter()
1530                                    .map(|range| {
1531                                        let point_range = range.to_point(&snapshot);
1532                                        Line(point_range.start.row)..Line(point_range.end.row)
1533                                    })
1534                                    .collect::<Vec<_>>();
1535                                let excerpts = assemble_excerpts(&snapshot, ranges);
1536                                predict_edits_v3::IncludedFile {
1537                                    path: path.clone(),
1538                                    max_row: Line(snapshot.max_point().row),
1539                                    excerpts,
1540                                }
1541                            })
1542                            .collect::<Vec<_>>();
1543
1544                        predict_edits_v3::PredictEditsRequest {
1545                            excerpt_path,
1546                            excerpt: String::new(),
1547                            excerpt_line_range: Line(0)..Line(0),
1548                            excerpt_range: 0..0,
1549                            cursor_point: predict_edits_v3::Point {
1550                                line: predict_edits_v3::Line(cursor_point.row),
1551                                column: cursor_point.column,
1552                            },
1553                            included_files,
1554                            referenced_declarations: vec![],
1555                            events,
1556                            can_collect_data,
1557                            diagnostic_groups,
1558                            diagnostic_groups_truncated,
1559                            debug_info: debug_tx.is_some(),
1560                            prompt_max_bytes: Some(options.max_prompt_bytes),
1561                            prompt_format: options.prompt_format,
1562                            // TODO [zeta2]
1563                            signatures: vec![],
1564                            excerpt_parent: None,
1565                            git_info: None,
1566                        }
1567                    }
1568                    ContextMode::Syntax(context_options) => {
1569                        let Some(context) = EditPredictionContext::gather_context(
1570                            cursor_point,
1571                            &active_snapshot,
1572                            parent_abs_path.as_deref(),
1573                            &context_options,
1574                            index_state.as_deref(),
1575                        ) else {
1576                            return Ok((None, None));
1577                        };
1578
1579                        make_syntax_context_cloud_request(
1580                            excerpt_path,
1581                            context,
1582                            events,
1583                            can_collect_data,
1584                            diagnostic_groups,
1585                            diagnostic_groups_truncated,
1586                            None,
1587                            debug_tx.is_some(),
1588                            &worktree_snapshots,
1589                            index_state.as_deref(),
1590                            Some(options.max_prompt_bytes),
1591                            options.prompt_format,
1592                        )
1593                    }
1594                };
1595
1596                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
1597
1598                let inputs = EditPredictionInputs {
1599                    included_files: cloud_request.included_files,
1600                    events: cloud_request.events,
1601                    cursor_point: cloud_request.cursor_point,
1602                    cursor_path: cloud_request.excerpt_path,
1603                };
1604
1605                let retrieval_time = Instant::now() - before_retrieval;
1606
1607                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
1608                    let (response_tx, response_rx) = oneshot::channel();
1609
1610                    debug_tx
1611                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
1612                            ZetaEditPredictionDebugInfo {
1613                                inputs: inputs.clone(),
1614                                retrieval_time,
1615                                buffer: active_buffer.downgrade(),
1616                                local_prompt: match prompt_result.as_ref() {
1617                                    Ok((prompt, _)) => Ok(prompt.clone()),
1618                                    Err(err) => Err(err.to_string()),
1619                                },
1620                                position,
1621                                response_rx,
1622                            },
1623                        ))
1624                        .ok();
1625                    Some(response_tx)
1626                } else {
1627                    None
1628                };
1629
1630                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
1631                    if let Some(debug_response_tx) = debug_response_tx {
1632                        debug_response_tx
1633                            .send((Err("Request skipped".to_string()), Duration::ZERO))
1634                            .ok();
1635                    }
1636                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
1637                }
1638
1639                let (prompt, _) = prompt_result?;
1640                let generation_params =
1641                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
1642                let request = open_ai::Request {
1643                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1644                    messages: vec![open_ai::RequestMessage::User {
1645                        content: open_ai::MessageContent::Plain(prompt),
1646                    }],
1647                    stream: false,
1648                    max_completion_tokens: None,
1649                    stop: generation_params.stop.unwrap_or_default(),
1650                    temperature: generation_params.temperature.unwrap_or(0.7),
1651                    tool_choice: None,
1652                    parallel_tool_calls: None,
1653                    tools: vec![],
1654                    prompt_cache_key: None,
1655                    reasoning_effort: None,
1656                };
1657
1658                log::trace!("Sending edit prediction request");
1659
1660                let before_request = Instant::now();
1661                let response = Self::send_raw_llm_request(
1662                    request,
1663                    client,
1664                    llm_token,
1665                    app_version,
1666                    #[cfg(feature = "eval-support")]
1667                    eval_cache,
1668                    #[cfg(feature = "eval-support")]
1669                    EvalCacheEntryKind::Prediction,
1670                )
1671                .await;
1672                let received_response_at = Instant::now();
1673                let request_time = received_response_at - before_request;
1674
1675                log::trace!("Got edit prediction response");
1676
1677                if let Some(debug_response_tx) = debug_response_tx {
1678                    debug_response_tx
1679                        .send((
1680                            response
1681                                .as_ref()
1682                                .map_err(|err| err.to_string())
1683                                .map(|response| response.0.clone()),
1684                            request_time,
1685                        ))
1686                        .ok();
1687                }
1688
1689                let (res, usage) = response?;
1690                let request_id = EditPredictionId(res.id.clone().into());
1691                let Some(mut output_text) = text_from_response(res) else {
1692                    return Ok((None, usage));
1693                };
1694
1695                if output_text.contains(CURSOR_MARKER) {
1696                    log::trace!("Stripping out {CURSOR_MARKER} from response");
1697                    output_text = output_text.replace(CURSOR_MARKER, "");
1698                }
1699
1700                let get_buffer_from_context = |path: &Path| {
1701                    included_files
1702                        .iter()
1703                        .find_map(|(_, buffer, probe_path, ranges)| {
1704                            if probe_path.as_ref() == path {
1705                                Some((buffer, ranges.as_slice()))
1706                            } else {
1707                                None
1708                            }
1709                        })
1710                };
1711
1712                let (edited_buffer_snapshot, edits) = match options.prompt_format {
1713                    PromptFormat::NumLinesUniDiff => {
1714                        // TODO: Implement parsing of multi-file diffs
1715                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1716                    }
1717                    PromptFormat::Minimal
1718                    | PromptFormat::MinimalQwen
1719                    | PromptFormat::SeedCoder1120 => {
1720                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
1721                            let edits = vec![];
1722                            (&active_snapshot, edits)
1723                        } else {
1724                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1725                        }
1726                    }
1727                    PromptFormat::OldTextNewText => {
1728                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
1729                            .await?
1730                    }
1731                    _ => {
1732                        bail!("unsupported prompt format {}", options.prompt_format)
1733                    }
1734                };
1735
1736                let edited_buffer = included_files
1737                    .iter()
1738                    .find_map(|(buffer, snapshot, _, _)| {
1739                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1740                            Some(buffer.clone())
1741                        } else {
1742                            None
1743                        }
1744                    })
1745                    .context("Failed to find buffer in included_buffers")?;
1746
1747                anyhow::Ok((
1748                    Some((
1749                        request_id,
1750                        inputs,
1751                        edited_buffer,
1752                        edited_buffer_snapshot.clone(),
1753                        edits,
1754                        received_response_at,
1755                    )),
1756                    usage,
1757                ))
1758            }
1759        });
1760
1761        cx.spawn({
1762            async move |this, cx| {
1763                let Some((
1764                    id,
1765                    inputs,
1766                    edited_buffer,
1767                    edited_buffer_snapshot,
1768                    edits,
1769                    received_response_at,
1770                )) = Self::handle_api_response(&this, request_task.await, cx)?
1771                else {
1772                    return Ok(None);
1773                };
1774
1775                // TODO telemetry: duration, etc
1776                Ok(EditPrediction::new(
1777                    id,
1778                    &edited_buffer,
1779                    &edited_buffer_snapshot,
1780                    edits.into(),
1781                    buffer_snapshotted_at,
1782                    received_response_at,
1783                    inputs,
1784                    cx,
1785                )
1786                .await)
1787            }
1788        })
1789    }
1790
    /// Sends an OpenAI-style request to the edit-prediction endpoint and
    /// deserializes the response, along with any usage info taken from the
    /// response headers.
    ///
    /// The URL comes from the `PREDICT_EDITS_URL` override when set, otherwise
    /// from the client's `/predict_edits/raw` LLM endpoint. With the
    /// `eval-support` feature, responses are cached keyed by a hash of the
    /// URL plus the serialized request, so repeated eval runs skip the network.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            // Key on both the endpoint and the exact request payload.
            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                // Cache hit: no usage headers are available for cached responses.
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Populate the eval cache on a miss so subsequent runs are hermetic.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1848
1849    fn handle_api_response<T>(
1850        this: &WeakEntity<Self>,
1851        response: Result<(T, Option<EditPredictionUsage>)>,
1852        cx: &mut gpui::AsyncApp,
1853    ) -> Result<T> {
1854        match response {
1855            Ok((data, usage)) => {
1856                if let Some(usage) = usage {
1857                    this.update(cx, |this, cx| {
1858                        this.user_store.update(cx, |user_store, cx| {
1859                            user_store.update_edit_prediction_usage(usage, cx);
1860                        });
1861                    })
1862                    .ok();
1863                }
1864                Ok(data)
1865            }
1866            Err(err) => {
1867                if err.is::<ZedUpdateRequiredError>() {
1868                    cx.update(|cx| {
1869                        this.update(cx, |this, _cx| {
1870                            this.update_required = true;
1871                        })
1872                        .ok();
1873
1874                        let error_message: SharedString = err.to_string().into();
1875                        show_app_notification(
1876                            NotificationId::unique::<ZedUpdateRequiredError>(),
1877                            cx,
1878                            move |cx| {
1879                                cx.new(|cx| {
1880                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1881                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1882                                })
1883                            },
1884                        );
1885                    })
1886                    .ok();
1887                }
1888                Err(err)
1889            }
1890        }
1891    }
1892
    /// POSTs a request produced by `build` to the LLM backend, authenticating
    /// with `llm_token`, and retrying exactly once with a refreshed token when
    /// the server reports the token has expired.
    ///
    /// On success, returns the JSON-deserialized response body plus any usage
    /// info parsed from the response headers. Fails with
    /// `ZedUpdateRequiredError` when the server's minimum-required-version
    /// header exceeds `app_version`.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // The server may demand a minimum client version regardless of
            // the response status; check it before inspecting the body.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired LLM token: refresh it and loop to retry (once).
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1956
    /// How long without a context refresh before the user counts as idle;
    /// editing after this window triggers an immediate context refetch.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied to context refetches while the user is actively editing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1959
1960    // Refresh the related excerpts when the user just beguns editing after
1961    // an idle period, and after they pause editing.
1962    fn refresh_context_if_needed(
1963        &mut self,
1964        project: &Entity<Project>,
1965        buffer: &Entity<language::Buffer>,
1966        cursor_position: language::Anchor,
1967        cx: &mut Context<Self>,
1968    ) {
1969        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
1970            return;
1971        }
1972
1973        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1974            return;
1975        };
1976
1977        let now = Instant::now();
1978        let was_idle = zeta_project
1979            .refresh_context_timestamp
1980            .map_or(true, |timestamp| {
1981                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1982            });
1983        zeta_project.refresh_context_timestamp = Some(now);
1984        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1985            let buffer = buffer.clone();
1986            let project = project.clone();
1987            async move |this, cx| {
1988                if was_idle {
1989                    log::debug!("refetching edit prediction context after idle");
1990                } else {
1991                    cx.background_executor()
1992                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1993                        .await;
1994                    log::debug!("refetching edit prediction context after pause");
1995                }
1996                this.update(cx, |this, cx| {
1997                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
1998
1999                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
2000                        zeta_project.refresh_context_task = Some(task.log_err());
2001                    };
2002                })
2003                .ok()
2004            }
2005        }));
2006    }
2007
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    //
    // Flow: build a planning prompt from the excerpt around the cursor, ask the
    // LLM for search queries via a tool call, run those searches over the
    // project, and store the resulting excerpts as the project's context.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // No-op unless the project is registered and agentic context mode is on.
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        // Select an excerpt around the cursor to seed the planning prompt.
        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        // Fall back to "untitled" for buffers that aren't backed by a file.
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project.events(cx),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Tool schema + description for the search tool, computed once per process.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Accumulate search queries from every well-formed tool call;
            // calls to unknown tools are logged and skipped.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                // Clear the in-flight task handle regardless of outcome.
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
2219
2220    pub fn set_context(
2221        &mut self,
2222        project: Entity<Project>,
2223        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2224    ) {
2225        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2226            zeta_project.context = Some(context);
2227        }
2228    }
2229
2230    fn gather_nearby_diagnostics(
2231        cursor_offset: usize,
2232        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
2233        snapshot: &BufferSnapshot,
2234        max_diagnostics_bytes: usize,
2235    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
2236        // TODO: Could make this more efficient
2237        let mut diagnostic_groups = Vec::new();
2238        for (language_server_id, diagnostics) in diagnostic_sets {
2239            let mut groups = Vec::new();
2240            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
2241            diagnostic_groups.extend(
2242                groups
2243                    .into_iter()
2244                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
2245            );
2246        }
2247
2248        // sort by proximity to cursor
2249        diagnostic_groups.sort_by_key(|group| {
2250            let range = &group.entries[group.primary_ix].range;
2251            if range.start >= cursor_offset {
2252                range.start - cursor_offset
2253            } else if cursor_offset >= range.end {
2254                cursor_offset - range.end
2255            } else {
2256                (cursor_offset - range.start).min(range.end - cursor_offset)
2257            }
2258        });
2259
2260        let mut results = Vec::new();
2261        let mut diagnostic_groups_truncated = false;
2262        let mut diagnostics_byte_count = 0;
2263        for group in diagnostic_groups {
2264            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
2265            diagnostics_byte_count += raw_value.get().len();
2266            if diagnostics_byte_count > max_diagnostics_bytes {
2267                diagnostic_groups_truncated = true;
2268                break;
2269            }
2270            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
2271        }
2272
2273        (results, diagnostic_groups_truncated)
2274    }
2275
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds a `PredictEditsRequest` for the zeta CLI from the current buffer
    /// and cursor position, using syntax-based context gathering.
    ///
    /// Panics when the options are in agentic context mode (not supported by
    /// the CLI yet), and fails when the buffer has no file path or no excerpt
    /// can be selected.
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, when it has one.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            // Hold the syntax index lock for the duration of context gathering.
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
2353
2354    pub fn wait_for_initial_indexing(
2355        &mut self,
2356        project: &Entity<Project>,
2357        cx: &mut Context<Self>,
2358    ) -> Task<Result<()>> {
2359        let zeta_project = self.get_or_init_zeta_project(project, cx);
2360        if let Some(syntax_index) = &zeta_project.syntax_index {
2361            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2362        } else {
2363            Task::ready(Ok(()))
2364        }
2365    }
2366
2367    fn is_file_open_source(
2368        &self,
2369        project: &Entity<Project>,
2370        file: &Arc<dyn File>,
2371        cx: &App,
2372    ) -> bool {
2373        if !file.is_local() || file.is_private() {
2374            return false;
2375        }
2376        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2377            return false;
2378        };
2379        zeta_project
2380            .license_detection_watchers
2381            .get(&file.worktree_id(cx))
2382            .as_ref()
2383            .is_some_and(|watcher| watcher.is_project_open_source())
2384    }
2385
2386    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2387        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2388    }
2389
2390    fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2391        if !self.data_collection_choice.is_enabled() {
2392            return false;
2393        }
2394        events.iter().all(|event| {
2395            matches!(
2396                event.as_ref(),
2397                Event::BufferChange {
2398                    in_open_source_repo: true,
2399                    ..
2400                }
2401            )
2402        })
2403    }
2404
2405    fn load_data_collection_choice() -> DataCollectionChoice {
2406        let choice = KEY_VALUE_STORE
2407            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2408            .log_err()
2409            .flatten();
2410
2411        match choice.as_deref() {
2412            Some("true") => DataCollectionChoice::Enabled,
2413            Some("false") => DataCollectionChoice::Disabled,
2414            Some(_) => {
2415                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2416                DataCollectionChoice::NotAnswered
2417            }
2418            None => DataCollectionChoice::NotAnswered,
2419        }
2420    }
2421
    /// Iterates over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }

    /// Number of predictions that have been shown to the user.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }

    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2433
    /// Records the user's rating for a prediction and reports it to telemetry.
    ///
    /// The prediction is remembered as rated (see [`Self::is_prediction_rated`])
    /// and telemetry events are flushed immediately rather than waiting for a
    /// periodic flush.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2452}
2453
2454pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2455    let choice = res.choices.pop()?;
2456    let output_text = match choice.message {
2457        open_ai::RequestMessage::Assistant {
2458            content: Some(open_ai::MessageContent::Plain(content)),
2459            ..
2460        } => content,
2461        open_ai::RequestMessage::Assistant {
2462            content: Some(open_ai::MessageContent::Multipart(mut content)),
2463            ..
2464        } => {
2465            if content.is_empty() {
2466                log::error!("No output from Baseten completion response");
2467                return None;
2468            }
2469
2470            match content.remove(0) {
2471                open_ai::MessagePart::Text { text } => text,
2472                open_ai::MessagePart::Image { .. } => {
2473                    log::error!("Expected text, got an image");
2474                    return None;
2475                }
2476            }
2477        }
2478        _ => {
2479            log::error!("Invalid response message: {:?}", choice.message);
2480            return None;
2481        }
2482    };
2483    Some(output_text)
2484}
2485
/// Error raised when the running Zed build is too old for the edit-prediction
/// service. NOTE(review): presumably populated from the
/// `MINIMUM_REQUIRED_VERSION_HEADER_NAME` response header — confirm at the
/// construction site.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2493
2494fn make_syntax_context_cloud_request(
2495    excerpt_path: Arc<Path>,
2496    context: EditPredictionContext,
2497    events: Vec<Arc<predict_edits_v3::Event>>,
2498    can_collect_data: bool,
2499    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2500    diagnostic_groups_truncated: bool,
2501    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2502    debug_info: bool,
2503    worktrees: &Vec<worktree::Snapshot>,
2504    index_state: Option<&SyntaxIndexState>,
2505    prompt_max_bytes: Option<usize>,
2506    prompt_format: PromptFormat,
2507) -> predict_edits_v3::PredictEditsRequest {
2508    let mut signatures = Vec::new();
2509    let mut declaration_to_signature_index = HashMap::default();
2510    let mut referenced_declarations = Vec::new();
2511
2512    for snippet in context.declarations {
2513        let project_entry_id = snippet.declaration.project_entry_id();
2514        let Some(path) = worktrees.iter().find_map(|worktree| {
2515            worktree.entry_for_id(project_entry_id).map(|entry| {
2516                let mut full_path = RelPathBuf::new();
2517                full_path.push(worktree.root_name());
2518                full_path.push(&entry.path);
2519                full_path
2520            })
2521        }) else {
2522            continue;
2523        };
2524
2525        let parent_index = index_state.and_then(|index_state| {
2526            snippet.declaration.parent().and_then(|parent| {
2527                add_signature(
2528                    parent,
2529                    &mut declaration_to_signature_index,
2530                    &mut signatures,
2531                    index_state,
2532                )
2533            })
2534        });
2535
2536        let (text, text_is_truncated) = snippet.declaration.item_text();
2537        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2538            path: path.as_std_path().into(),
2539            text: text.into(),
2540            range: snippet.declaration.item_line_range(),
2541            text_is_truncated,
2542            signature_range: snippet.declaration.signature_range_in_item_text(),
2543            parent_index,
2544            signature_score: snippet.score(DeclarationStyle::Signature),
2545            declaration_score: snippet.score(DeclarationStyle::Declaration),
2546            score_components: snippet.components,
2547        });
2548    }
2549
2550    let excerpt_parent = index_state.and_then(|index_state| {
2551        context
2552            .excerpt
2553            .parent_declarations
2554            .last()
2555            .and_then(|(parent, _)| {
2556                add_signature(
2557                    *parent,
2558                    &mut declaration_to_signature_index,
2559                    &mut signatures,
2560                    index_state,
2561                )
2562            })
2563    });
2564
2565    predict_edits_v3::PredictEditsRequest {
2566        excerpt_path,
2567        excerpt: context.excerpt_text.body,
2568        excerpt_line_range: context.excerpt.line_range,
2569        excerpt_range: context.excerpt.range,
2570        cursor_point: predict_edits_v3::Point {
2571            line: predict_edits_v3::Line(context.cursor_point.row),
2572            column: context.cursor_point.column,
2573        },
2574        referenced_declarations,
2575        included_files: vec![],
2576        signatures,
2577        excerpt_parent,
2578        events,
2579        can_collect_data,
2580        diagnostic_groups,
2581        diagnostic_groups_truncated,
2582        git_info,
2583        debug_info,
2584        prompt_max_bytes,
2585        prompt_format,
2586    }
2587}
2588
2589fn add_signature(
2590    declaration_id: DeclarationId,
2591    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2592    signatures: &mut Vec<Signature>,
2593    index: &SyntaxIndexState,
2594) -> Option<usize> {
2595    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2596        return Some(*signature_index);
2597    }
2598    let Some(parent_declaration) = index.declaration(declaration_id) else {
2599        log::error!("bug: missing parent declaration");
2600        return None;
2601    };
2602    let parent_index = parent_declaration.parent().and_then(|parent| {
2603        add_signature(parent, declaration_to_signature_index, signatures, index)
2604    });
2605    let (text, text_is_truncated) = parent_declaration.signature_text();
2606    let signature_index = signatures.len();
2607    signatures.push(Signature {
2608        text: text.into(),
2609        text_is_truncated,
2610        parent_index,
2611        range: parent_declaration.signature_line_range(),
2612    });
2613    declaration_to_signature_index.insert(declaration_id, signature_index);
2614    Some(signature_index)
2615}
2616
/// Cache key for eval runs: the kind of cached artifact plus a content hash.
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

/// The kind of artifact stored in the eval cache.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2627
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Renders the entry kind as its lowercase cache-key name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(name)
    }
}
2638
/// Read/write interface to a cache used during eval runs; `write` receives
/// both the raw input and the value so implementations can store them together.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2644
/// The user's choice about edit-prediction data collection. Until the user
/// answers, collection is treated as disabled.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// Whether the user has actively opted in to data collection.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Whether the user has made an explicit choice either way.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Switches between enabled and disabled; an unanswered choice becomes
    /// enabled.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        if matches!(self, Self::Enabled) {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        if value { Self::Enabled } else { Self::Disabled }
    }
}
2685
2686struct ZedPredictUpsell;
2687
2688impl Dismissable for ZedPredictUpsell {
2689    const KEY: &'static str = "dismissed-edit-predict-upsell";
2690
2691    fn dismissed() -> bool {
2692        // To make this backwards compatible with older versions of Zed, we
2693        // check if the user has seen the previous Edit Prediction Onboarding
2694        // before, by checking the data collection choice which was written to
2695        // the database once the user clicked on "Accept and Enable"
2696        if KEY_VALUE_STORE
2697            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2698            .log_err()
2699            .is_some_and(|s| s.is_some())
2700        {
2701            return true;
2702        }
2703
2704        KEY_VALUE_STORE
2705            .read_kvp(Self::KEY)
2706            .log_err()
2707            .is_some_and(|s| s.is_some())
2708    }
2709}
2710
/// The upsell modal should be shown until the user has dismissed it (or is
/// treated as having dismissed it via the legacy onboarding state).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2714
/// Registers Zeta's workspace actions and wires up command-palette
/// feature-gating.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Rating UI is only reachable behind the feature flag, even if the
        // action is invoked directly (e.g. via a keybinding).
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // NOTE(review): despite the name, ResetOnboarding currently writes the
        // edit-prediction provider back to None in the settings file — confirm
        // this is the intended "reset" behavior.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2750
/// Hides Zeta's actions from the command palette by default, then keeps their
/// visibility in sync with the AI-disabled setting and the rate-completions
/// feature flag.
fn feature_gate_predict_edits_actions(cx: &mut App) {
    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
    let zeta_all_action_types = [
        TypeId::of::<RateCompletions>(),
        TypeId::of::<ResetOnboarding>(),
        zed_actions::OpenZedPredictOnboarding.type_id(),
        TypeId::of::<ClearHistory>(),
        TypeId::of::<ThumbsUpActivePrediction>(),
        TypeId::of::<ThumbsDownActivePrediction>(),
        TypeId::of::<NextEdit>(),
        TypeId::of::<PreviousEdit>(),
    ];

    // Hidden by default; the observers below re-show what's appropriate.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_action_types(&rate_completion_action_types);
        filter.hide_action_types(&reset_onboarding_action_types);
        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
    });

    // Settings changes: disabling AI hides everything; otherwise the
    // rate-completions actions follow the feature flag.
    cx.observe_global::<SettingsStore>(move |cx| {
        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            if is_ai_disabled {
                filter.hide_action_types(&zeta_all_action_types);
            } else if has_feature_flag {
                filter.show_action_types(&rate_completion_action_types);
            } else {
                filter.hide_action_types(&rate_completion_action_types);
            }
        });
    })
    .detach();

    // Feature-flag changes: only adjust visibility while AI is enabled; the
    // settings observer above handles the AI-disabled case.
    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
        if !DisableAiSettings::get_global(cx).disable_ai {
            if is_enabled {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.show_action_types(&rate_completion_action_types);
                });
            } else {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.hide_action_types(&rate_completion_action_types);
                });
            }
        }
    })
    .detach();
}
2802
2803#[cfg(test)]
2804mod tests {
2805    use std::{path::Path, sync::Arc};
2806
2807    use client::UserStore;
2808    use clock::FakeSystemClock;
2809    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2810    use futures::{
2811        AsyncReadExt, StreamExt,
2812        channel::{mpsc, oneshot},
2813    };
2814    use gpui::{
2815        Entity, TestAppContext,
2816        http_client::{FakeHttpClient, Response},
2817        prelude::*,
2818    };
2819    use indoc::indoc;
2820    use language::OffsetRangeExt as _;
2821    use open_ai::Usage;
2822    use pretty_assertions::{assert_eq, assert_matches};
2823    use project::{FakeFs, Project};
2824    use serde_json::json;
2825    use settings::SettingsStore;
2826    use util::path;
2827    use uuid::Uuid;
2828
2829    use crate::{BufferEditPrediction, Zeta};
2830
    // Exercises the per-project "current prediction" state machine: a local
    // prediction for the active buffer, a context refresh driven by a
    // tool-call response, and a prediction targeting a different file, which
    // surfaces as `Jump` from the active buffer and `Local` from the target.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // Respond with a diff that edits the buffer the cursor is in.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // The context request is answered with a search tool call rather than
        // plain text.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.discard_current_prediction(&project, cx);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // From buffer1's perspective the prediction is a jump to 2.txt.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // From the target buffer's perspective the same prediction is local.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
2977
    // End-to-end happy path: a single prediction request whose diff response
    // is applied as one buffer edit at the cursor position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The "How" -> "How are you?" replacement is minimized to a single
        // insertion of " are you?" after column 3.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3041
    // Verifies that edit-history events from a registered buffer are included
    // in the prediction prompt as a unified diff, and that the response is
    // still applied correctly afterwards.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Registering the buffer makes Zeta record subsequent edits as events.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The recorded edit should appear in the prompt as a unified diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3115
3116    // Skipped until we start including diagnostics in prompt
3117    // #[gpui::test]
3118    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3119    //     let (zeta, mut req_rx) = init_test(cx);
3120    //     let fs = FakeFs::new(cx.executor());
3121    //     fs.insert_tree(
3122    //         "/root",
3123    //         json!({
3124    //             "foo.md": "Hello!\nBye"
3125    //         }),
3126    //     )
3127    //     .await;
3128    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3129
3130    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3131    //     let diagnostic = lsp::Diagnostic {
3132    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3133    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
3134    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3135    //         ..Default::default()
3136    //     };
3137
3138    //     project.update(cx, |project, cx| {
3139    //         project.lsp_store().update(cx, |lsp_store, cx| {
3140    //             // Create some diagnostics
3141    //             lsp_store
3142    //                 .update_diagnostics(
3143    //                     LanguageServerId(0),
3144    //                     lsp::PublishDiagnosticsParams {
3145    //                         uri: path_to_buffer_uri.clone(),
3146    //                         diagnostics: vec![diagnostic],
3147    //                         version: None,
3148    //                     },
3149    //                     None,
3150    //                     language::DiagnosticSourceKind::Pushed,
3151    //                     &[],
3152    //                     cx,
3153    //                 )
3154    //                 .unwrap();
3155    //         });
3156    //     });
3157
3158    //     let buffer = project
3159    //         .update(cx, |project, cx| {
3160    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3161    //             project.open_buffer(path, cx)
3162    //         })
3163    //         .await
3164    //         .unwrap();
3165
3166    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3167    //     let position = snapshot.anchor_before(language::Point::new(0, 0));
3168
3169    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
3170    //         zeta.request_prediction(&project, &buffer, position, cx)
3171    //     });
3172
3173    //     let (request, _respond_tx) = req_rx.next().await.unwrap();
3174
3175    //     assert_eq!(request.diagnostic_groups.len(), 1);
3176    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3177    //         .unwrap();
3178    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3179    //     assert_eq!(
3180    //         value,
3181    //         json!({
3182    //             "entries": [{
3183    //                 "range": {
3184    //                     "start": 8,
3185    //                     "end": 10
3186    //                 },
3187    //                 "diagnostic": {
3188    //                     "source": null,
3189    //                     "code": null,
3190    //                     "code_description": null,
3191    //                     "severity": 1,
3192    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3193    //                     "markdown": null,
3194    //                     "group_id": 0,
3195    //                     "is_primary": true,
3196    //                     "is_disk_based": false,
3197    //                     "is_unnecessary": false,
3198    //                     "source_kind": "Pushed",
3199    //                     "data": null,
3200    //                     "underline": true
3201    //                 }
3202    //             }],
3203    //             "primary_ix": 0
3204    //         })
3205    //     );
3206    // }
3207
3208    fn model_response(text: &str) -> open_ai::Response {
3209        open_ai::Response {
3210            id: Uuid::new_v4().to_string(),
3211            object: "response".into(),
3212            created: 0,
3213            model: "model".into(),
3214            choices: vec![open_ai::Choice {
3215                index: 0,
3216                message: open_ai::RequestMessage::Assistant {
3217                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
3218                    tool_calls: vec![],
3219                },
3220                finish_reason: None,
3221            }],
3222            usage: Usage {
3223                prompt_tokens: 0,
3224                completion_tokens: 0,
3225                total_tokens: 0,
3226            },
3227        }
3228    }
3229
3230    fn prompt_from_request(request: &open_ai::Request) -> &str {
3231        assert_eq!(request.messages.len(), 1);
3232        let open_ai::RequestMessage::User {
3233            content: open_ai::MessageContent::Plain(content),
3234            ..
3235        } = &request.messages[0]
3236        else {
3237            panic!(
3238                "Request does not have single user message of type Plain. {:#?}",
3239                request
3240            );
3241        };
3242        content
3243    }
3244
    /// Sets up a test app with a `Zeta` entity backed by a fake HTTP client.
    ///
    /// Returns the `Zeta` entity plus an unbounded receiver that yields each
    /// request POSTed to `/predict_edits/raw`, paired with a oneshot sender
    /// the test uses to supply the fake model response for that request.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            // Channel carrying (decoded request, responder) pairs out to the test.
            let (req_tx, req_rx) = mpsc::unbounded();

            // Fake server: answers the token endpoint with a static token and
            // forwards prediction requests to the test via `req_tx`, holding
            // the HTTP response open until the test replies on the oneshot.
            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // LLM token refresh: always succeed with a test token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                // Best-effort read; a failed read leaves `buf`
                                // partial and the JSON parse below will panic.
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the request to the test and block until
                                // it sends back the canned model response.
                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            // Pre-authenticate so requests don't trigger a real sign-in flow.
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
3299}