zeta.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection,
   7    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   8    RejectEditPredictionsBody, ZED_VERSION_HEADER_NAME,
   9};
  10use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  11use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  12use collections::{HashMap, HashSet};
  13use command_palette_hooks::CommandPaletteFilter;
  14use db::kvp::{Dismissable, KEY_VALUE_STORE};
  15use edit_prediction_context::{
  16    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  17    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  18    SyntaxIndex, SyntaxIndexState,
  19};
  20use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
  21use futures::channel::{mpsc, oneshot};
  22use futures::{AsyncReadExt as _, StreamExt as _};
  23use gpui::{
  24    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  25    http_client::{self, AsyncBody, Method},
  26    prelude::*,
  27};
  28use language::{
  29    Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
  30};
  31use language::{BufferSnapshot, OffsetRangeExt};
  32use language_model::{LlmApiToken, RefreshLlmTokenListener};
  33use open_ai::FunctionDefinition;
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
  39use std::any::{Any as _, TypeId};
  40use std::collections::{VecDeque, hash_map};
  41use telemetry_events::EditPredictionRating;
  42use workspace::Workspace;
  43
  44use std::ops::Range;
  45use std::path::Path;
  46use std::rc::Rc;
  47use std::str::FromStr as _;
  48use std::sync::{Arc, LazyLock};
  49use std::time::{Duration, Instant};
  50use std::{env, mem};
  51use thiserror::Error;
  52use util::rel_path::RelPathBuf;
  53use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
  54use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  55
  56pub mod assemble_excerpts;
  57mod license_detection;
  58mod onboarding_modal;
  59mod prediction;
  60mod provider;
  61mod rate_prediction_modal;
  62pub mod retrieval_search;
  63mod sweep_ai;
  64pub mod udiff;
  65mod xml_edits;
  66pub mod zeta1;
  67
  68#[cfg(test)]
  69mod zeta_tests;
  70
  71use crate::assemble_excerpts::assemble_excerpts;
  72use crate::license_detection::LicenseDetectionWatcher;
  73use crate::onboarding_modal::ZedPredictModal;
  74pub use crate::prediction::EditPrediction;
  75pub use crate::prediction::EditPredictionId;
  76pub use crate::prediction::EditPredictionInputs;
  77use crate::rate_prediction_modal::{
  78    NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
  79    ThumbsUpActivePrediction,
  80};
  81use crate::sweep_ai::SweepAi;
  82use crate::zeta1::request_prediction_with_zeta1;
  83pub use provider::ZetaEditPredictionProvider;
  84
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the modal for rating edit predictions.
        RateCompletions,
        /// Clears the recorded edit-prediction event history.
        ClearHistory,
    ]
);
  96
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Successive edits whose end points fall within this many lines of each
/// other are coalesced into a single buffer-change event.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key-value-store key under which the user's data-collection choice persists.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 101
/// Feature flag gating the Sweep AI edit-prediction backend.
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}
/// Default sizing for the excerpt of text captured around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim to place the cursor roughly halfway through the excerpt.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Context retrieval defaults to the agentic (search-driven) mode.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Defaults for [`ContextMode::Agentic`].
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for syntax-index-based context retrieval.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Overall default [`ZetaOptions`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
 138
 139static USE_OLLAMA: LazyLock<bool> =
 140    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 141static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 142    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
 143        "qwen3-coder:30b".to_string()
 144    } else {
 145        "yqvev8r3".to_string()
 146    })
 147});
 148static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 149    match env::var("ZED_ZETA2_MODEL").as_deref() {
 150        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 151        Ok(model) => model,
 152        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 153        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 154    }
 155    .to_string()
 156});
 157static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 158    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 159        if *USE_OLLAMA {
 160            Some("http://localhost:11434/v1/chat/completions".into())
 161        } else {
 162            None
 163        }
 164    })
 165});
 166
/// Feature flag gating the Zeta2 prediction pipeline; always on for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}
 176
/// Global handle to the single app-wide [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 181
/// App-wide coordinator for edit predictions.
///
/// Tracks per-project state, LLM authentication, the user's data-collection
/// choice, and bookkeeping for shown, rated, and rejected predictions.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Refreshes `llm_token` whenever the token listener fires.
    _llm_token_subscription: Subscription,
    // Keyed by `Project` entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server demands a newer client
    // (see MINIMUM_REQUIRED_VERSION_HEADER_NAME) — confirm where it's written.
    update_required: bool,
    // When present, debug info about retrieval/prediction requests streams here.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    // Rejections accumulate here and are uploaded in batches; sends on
    // `reject_predictions_tx` trigger the upload loop spawned in `new`.
    rejected_predictions: Vec<EditPredictionRejection>,
    reject_predictions_tx: mpsc::UnboundedSender<()>,
    reject_predictions_debounce_task: Option<Task<()>>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
 202
/// Which backend serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}
 210
/// Tunable parameters for prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context is retrieved (agentic search vs. syntax index).
    pub context: ContextMode,
    /// Upper bound on prompt size, in bytes.
    pub max_prompt_bytes: usize,
    /// Upper bound on diagnostic text included in the prompt, in bytes.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Worker parallelism used when building the syntax index.
    pub file_indexing_parallelism: usize,
    // NOTE(review): presumably the time window for coalescing buffer edits
    // into one event — the consumer isn't visible in this chunk; confirm.
    pub buffer_change_grouping_interval: Duration,
}
 220
/// How context for a prediction request is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Model-driven search queries select the context.
    Agentic(AgenticContextOptions),
    /// A local syntax index supplies declarations as context.
    Syntax(EditPredictionContextOptions),
}

/// Options for [`ContextMode::Agentic`].
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
 231
 232impl ContextMode {
 233    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 234        match self {
 235            ContextMode::Agentic(options) => &options.excerpt,
 236            ContextMode::Syntax(options) => &options.excerpt,
 237        }
 238    }
 239}
 240
/// Debug events streamed to the receiver returned by [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// Prompt sent to the context-retrieval model.
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    /// Time spent on context retrieval before this request was issued.
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// Locally assembled prompt, or an error message when assembly failed.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error) and the request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 281
/// Per-project prediction state.
struct ZetaProject {
    /// Present only when `ContextMode::Syntax` is configured.
    syntax_index: Option<Entity<SyntaxIndex>>,
    /// Finalized buffer-change events, oldest first (bounded by `EVENT_COUNT_MAX`).
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    /// The most recent change, kept un-finalized so nearby follow-up edits can
    /// still be coalesced into it.
    last_event: Option<LastEvent>,
    /// Recently active project paths, most recent first, deduplicated.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// At most two requests may be in flight concurrently.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Context excerpts retrieved for upcoming prediction requests.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    /// One license watcher per worktree; used to mark events as open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
 299
 300impl ZetaProject {
 301    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
 302        self.events
 303            .iter()
 304            .cloned()
 305            .chain(
 306                self.last_event
 307                    .as_ref()
 308                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 309            )
 310            .collect()
 311    }
 312}
 313
/// The prediction currently offered to the user, plus how it was requested.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been displayed at least once.
    pub was_shown: bool,
}
 320
impl CurrentEditPrediction {
    /// Whether `self` (a newly computed prediction) should replace
    /// `old_prediction` (the one currently held).
    ///
    /// A new prediction that no longer interpolates against its buffer is
    /// never taken; a stale old prediction is always replaced. Otherwise,
    /// when both are single edits requested by the same buffer, the old one
    /// is only replaced by a same-range extension of itself.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // New prediction no longer applies to current buffer content: keep old.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        // Predictions for a different buffer always win.
        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // Old prediction no longer applies: replace it.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            // Replace only when the new edit extends the old one in place.
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
 359
 360#[derive(Debug, Clone)]
 361enum PredictionRequestedBy {
 362    DiagnosticsUpdate,
 363    Buffer(EntityId),
 364}
 365
 366impl PredictionRequestedBy {
 367    pub fn buffer_id(&self) -> Option<EntityId> {
 368        match self {
 369            PredictionRequestedBy::DiagnosticsUpdate => None,
 370            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 371        }
 372    }
 373}
 374
/// A prediction request that is still in flight.
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction edits another buffer; offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}

/// Per-buffer bookkeeping: the last snapshot diffed against, plus the
/// edit/release subscriptions that keep the entry alive and current.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// The most recent (possibly still-coalescing) buffer change, not yet
/// finalized into a `predict_edits_v3::Event`.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    end_edit_anchor: Option<Anchor>,
}
 397
impl LastEvent {
    /// Converts this pending change into a `BufferChange` event, or `None`
    /// when nothing observable changed (same path and an empty diff).
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<Arc<predict_edits_v3::Event>> {
        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);

        let file = self.new_snapshot.file();
        let old_file = self.old_snapshot.file();

        // Marked open source only when both the old and new file belong to
        // worktrees whose license watcher says so; untitled buffers (no
        // file) never qualify.
        let in_open_source_repo = [file, old_file].iter().all(|file| {
            file.is_some_and(|file| {
                license_detection_watchers
                    .get(&file.worktree_id(cx))
                    .is_some_and(|watcher| watcher.is_project_open_source())
            })
        });

        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(Arc::new(predict_edits_v3::Event::BufferChange {
                old_path,
                path,
                diff,
                in_open_source_repo,
                // TODO: Actually detect if this edit was predicted or not
                predicted: false,
            }))
        }
    }
}
 434
 435fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
 436    if let Some(file) = snapshot.file() {
 437        file.full_path(cx).into()
 438    } else {
 439        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 440    }
 441}
 442
 443impl Zeta {
 444    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 445        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 446    }
 447
 448    pub fn global(
 449        client: &Arc<Client>,
 450        user_store: &Entity<UserStore>,
 451        cx: &mut App,
 452    ) -> Entity<Self> {
 453        cx.try_global::<ZetaGlobal>()
 454            .map(|global| global.0.clone())
 455            .unwrap_or_else(|| {
 456                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 457                cx.set_global(ZetaGlobal(zeta.clone()));
 458                zeta
 459            })
 460    }
 461
    /// Creates the coordinator: loads the persisted data-collection choice,
    /// wires up LLM-token refresh, and spawns the background loop that
    /// uploads batched prediction rejections.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        // Each message on `reject_tx` triggers an upload of the rejections
        // accumulated so far; the loop exits when the sender is dropped.
        let (reject_tx, mut reject_rx) = mpsc::unbounded();
        cx.spawn(async move |this, cx| {
            while let Some(()) = reject_rx.next().await {
                this.update(cx, |this, cx| this.reject_edit_predictions(cx))?
                    .await
                    .log_err();
            }
            anyhow::Ok(())
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the LLM token whenever the listener signals it expired.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            rejected_predictions: Vec::new(),
            reject_predictions_debounce_task: None,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
 509
    /// Selects which backend serves predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Whether a Sweep AI API token is configured.
    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_ai.api_token.is_some()
    }
 517
    /// Installs a cache used by evals to avoid repeated model calls.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 522
 523    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 524        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 525        self.debug_tx = Some(debug_watch_tx);
 526        debug_watch_rx
 527    }
 528
 529    pub fn options(&self) -> &ZetaOptions {
 530        &self.options
 531    }
 532
 533    pub fn set_options(&mut self, options: ZetaOptions) {
 534        self.options = options;
 535    }
 536
 537    pub fn clear_history(&mut self) {
 538        for zeta_project in self.projects.values_mut() {
 539            zeta_project.events.clear();
 540        }
 541    }
 542
 543    pub fn context_for_project(
 544        &self,
 545        project: &Entity<Project>,
 546    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 547        self.projects
 548            .get(&project.entity_id())
 549            .and_then(|project| {
 550                Some(
 551                    project
 552                        .context
 553                        .as_ref()?
 554                        .iter()
 555                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 556                )
 557            })
 558            .into_iter()
 559            .flatten()
 560    }
 561
 562    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 563        if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
 564            self.user_store.read(cx).edit_prediction_usage()
 565        } else {
 566            None
 567        }
 568    }
 569
    /// Ensures per-project prediction state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking `buffer` within `project`, subscribing to its edits so
    /// changes feed the prediction event history.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 583
    /// Fetches the state for `project`, creating it on first use.
    ///
    /// A syntax index is only built in `ContextMode::Syntax`; agentic mode
    /// retrieves context via model-driven searches instead.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
 615
 616    fn handle_project_event(
 617        &mut self,
 618        project: Entity<Project>,
 619        event: &project::Event,
 620        cx: &mut Context<Self>,
 621    ) {
 622        // TODO [zeta2] init with recent paths
 623        match event {
 624            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 625                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
 626                    return;
 627                };
 628                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 629                if let Some(path) = path {
 630                    if let Some(ix) = zeta_project
 631                        .recent_paths
 632                        .iter()
 633                        .position(|probe| probe == &path)
 634                    {
 635                        zeta_project.recent_paths.remove(ix);
 636                    }
 637                    zeta_project.recent_paths.push_front(path);
 638                }
 639            }
 640            project::Event::DiagnosticsUpdated { .. } => {
 641                if cx.has_flag::<Zeta2FeatureFlag>() {
 642                    self.refresh_prediction_from_diagnostics(project, cx);
 643                }
 644            }
 645            _ => (),
 646        }
 647    }
 648
    /// Ensures `buffer` is tracked in `zeta_project`, returning its entry.
    ///
    /// Also lazily installs a license-detection watcher for the buffer's
    /// worktree. Subscriptions keep the entry current (re-diff on edit) and
    /// clean it up when the buffer or worktree is released.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher when its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Feed every edit into the change-event history.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Remove the entry when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 708
    /// Records a buffer change into the project's event history.
    ///
    /// Consecutive edits to the same buffer that land within
    /// `CHANGE_GROUPING_LINE_SPAN` lines of the previous edit are coalesced
    /// into the pending `last_event` rather than producing a new event.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        // No-op if nothing changed since the snapshot we last diffed against.
        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Anchor of the last edit's end, used for proximity-based grouping.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // Coalesce only when this change directly continues the pending
            // event's snapshot chain for the same buffer...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and the new edit lands near the previous one.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Keep the history bounded: finalized events plus the pending one
        // may not exceed EVENT_COUNT_MAX.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        // Finalize the previous pending event before starting a new one.
        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
 771
 772    fn current_prediction_for_buffer(
 773        &self,
 774        buffer: &Entity<Buffer>,
 775        project: &Entity<Project>,
 776        cx: &App,
 777    ) -> Option<BufferEditPrediction<'_>> {
 778        let project_state = self.projects.get(&project.entity_id())?;
 779
 780        let CurrentEditPrediction {
 781            requested_by,
 782            prediction,
 783            ..
 784        } = project_state.current_prediction.as_ref()?;
 785
 786        if prediction.targets_buffer(buffer.read(cx)) {
 787            Some(BufferEditPrediction::Local { prediction })
 788        } else {
 789            let show_jump = match requested_by {
 790                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
 791                    requested_by_buffer_id == &buffer.entity_id()
 792                }
 793                PredictionRequestedBy::DiagnosticsUpdate => true,
 794            };
 795
 796            if show_jump {
 797                Some(BufferEditPrediction::Jump { prediction })
 798            } else {
 799                None
 800            }
 801        }
 802    }
 803
    /// Accepts the current prediction for `project`: takes it out of the
    /// project state, cancels any in-flight requests, and reports the accept
    /// to the backend (Sweep accepts are not reported).
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Accepting supersedes any predictions still in flight.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            self.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // ZED_ACCEPT_PREDICTION_URL overrides the endpoint for testing.
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 856
 857    fn reject_edit_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 858        match self.edit_prediction_model {
 859            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
 860            ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
 861        }
 862
 863        let client = self.client.clone();
 864        let llm_token = self.llm_token.clone();
 865        let app_version = AppVersion::global(cx);
 866        let last_rejection = self.rejected_predictions.last().cloned();
 867        let Some(last_rejection) = last_rejection else {
 868            return Task::ready(anyhow::Ok(()));
 869        };
 870
 871        let body = serde_json::to_string(&RejectEditPredictionsBody {
 872            rejections: self.rejected_predictions.clone(),
 873        })
 874        .ok();
 875
 876        cx.spawn(async move |this, cx| {
 877            let url = client
 878                .http_client()
 879                .build_zed_llm_url("/predict_edits/reject", &[])?;
 880
 881            cx.background_spawn(Self::send_api_request::<()>(
 882                move |builder| {
 883                    let req = builder.uri(url.as_ref()).body(body.clone().into());
 884                    Ok(req?)
 885                },
 886                client,
 887                llm_token,
 888                app_version,
 889            ))
 890            .await
 891            .context("Failed to reject edit predictions")?;
 892
 893            this.update(cx, |this, _| {
 894                if let Some(ix) = this
 895                    .rejected_predictions
 896                    .iter()
 897                    .position(|rejection| rejection.request_id == last_rejection.request_id)
 898                {
 899                    this.rejected_predictions.drain(..ix + 1);
 900                }
 901            })
 902        })
 903    }
 904
 905    fn discard_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 906        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 907            project_state.pending_predictions.clear();
 908            if let Some(prediction) = project_state.current_prediction.take() {
 909                self.discard_prediction(prediction.prediction.id, prediction.was_shown, cx);
 910            }
 911        };
 912    }
 913
 914    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
 915        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 916            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
 917                if !current_prediction.was_shown {
 918                    current_prediction.was_shown = true;
 919                    self.shown_predictions
 920                        .push_front(current_prediction.prediction.clone());
 921                    if self.shown_predictions.len() > 50 {
 922                        let completion = self.shown_predictions.pop_back().unwrap();
 923                        self.rated_predictions.remove(&completion.id);
 924                    }
 925                }
 926            }
 927        }
 928    }
 929
 930    fn discard_prediction(
 931        &mut self,
 932        prediction_id: EditPredictionId,
 933        was_shown: bool,
 934        cx: &mut Context<Self>,
 935    ) {
 936        self.rejected_predictions.push(EditPredictionRejection {
 937            request_id: prediction_id.to_string(),
 938            was_shown,
 939        });
 940
 941        let reached_request_limit =
 942            self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
 943        let reject_tx = self.reject_predictions_tx.clone();
 944        self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
 945            const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
 946            if !reached_request_limit {
 947                cx.background_executor()
 948                    .timer(DISCARD_COMPLETIONS_DEBOUNCE)
 949                    .await;
 950            }
 951            reject_tx.unbounded_send(()).log_err();
 952        }));
 953    }
 954
 955    fn cancel_pending_prediction(
 956        &self,
 957        pending_prediction: PendingPrediction,
 958        cx: &mut Context<Self>,
 959    ) {
 960        cx.spawn(async move |this, cx| {
 961            let Some(prediction_id) = pending_prediction.task.await else {
 962                return;
 963            };
 964
 965            this.update(cx, |this, cx| {
 966                this.discard_prediction(prediction_id, false, cx);
 967            })
 968            .ok();
 969        })
 970        .detach()
 971    }
 972
 973    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
 974        self.projects
 975            .get(&project.entity_id())
 976            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
 977    }
 978
 979    pub fn refresh_prediction_from_buffer(
 980        &mut self,
 981        project: Entity<Project>,
 982        buffer: Entity<Buffer>,
 983        position: language::Anchor,
 984        cx: &mut Context<Self>,
 985    ) {
 986        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
 987            let Some(request_task) = this
 988                .update(cx, |this, cx| {
 989                    this.request_prediction(&project, &buffer, position, cx)
 990                })
 991                .log_err()
 992            else {
 993                return Task::ready(anyhow::Ok(None));
 994            };
 995
 996            let project = project.clone();
 997            cx.spawn(async move |cx| {
 998                if let Some(prediction) = request_task.await? {
 999                    let id = prediction.id.clone();
1000                    this.update(cx, |this, cx| {
1001                        let project_state = this
1002                            .projects
1003                            .get_mut(&project.entity_id())
1004                            .context("Project not found")?;
1005
1006                        let new_prediction = CurrentEditPrediction {
1007                            requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
1008                            prediction: prediction,
1009                            was_shown: false,
1010                        };
1011
1012                        if project_state
1013                            .current_prediction
1014                            .as_ref()
1015                            .is_none_or(|old_prediction| {
1016                                new_prediction.should_replace_prediction(&old_prediction, cx)
1017                            })
1018                        {
1019                            project_state.current_prediction = Some(new_prediction);
1020                            cx.notify();
1021                        }
1022                        anyhow::Ok(())
1023                    })??;
1024                    Ok(Some(id))
1025                } else {
1026                    Ok(None)
1027                }
1028            })
1029        })
1030    }
1031
    /// Requests a prediction driven by diagnostics rather than a cursor edit:
    /// starting from the project's active entry, finds the next diagnostic
    /// location and requests a prediction there.
    ///
    /// Skipped entirely when a prediction already exists; a prediction that
    /// appears while this request is in flight is also never overwritten
    /// (see `get_or_insert_with` below).
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Start the search from the buffer of the project's active entry.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Default (empty) search range and cursor point: no prior
                // request to exclude diagnostics for.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(&project, &jump_buffer, jump_position, cx)
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                let id = prediction.id.clone();
                this.update(cx, |this, cx| {
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        // Only install if no buffer-initiated prediction arrived
                        // in the meantime; notify only when we actually insert.
                        zeta_project.current_prediction.get_or_insert_with(|| {
                            cx.notify();
                            CurrentEditPrediction {
                                requested_by: PredictionRequestedBy::DiagnosticsUpdate,
                                prediction,
                                was_shown: false,
                            }
                        });
                    }
                })?;

                anyhow::Ok(Some(id))
            })
        });
    }
1104
    /// Minimum delay between consecutive prediction refreshes that target the
    /// same entity (see `queue_prediction_refresh`).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Zero in tests so they don't have to wait out the throttle window.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1109
    /// Schedules a prediction refresh, throttling back-to-back refreshes that
    /// target the same entity and keeping at most two requests pending.
    ///
    /// `throttle_entity` identifies what triggered the refresh (a buffer, or the
    /// project itself for diagnostics-driven refreshes); consecutive refreshes
    /// for the same entity are delayed by [`Self::THROTTLE_TIMEOUT`].
    /// `do_refresh` performs the actual request and resolves to the id of the
    /// prediction it produced, if any.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        ) -> Task<Result<Option<EditPredictionId>>>
        + 'static,
    ) {
        let zeta_project = self.get_or_init_zeta_project(&project, cx);
        let pending_prediction_id = zeta_project.next_pending_prediction_id;
        zeta_project.next_pending_prediction_id += 1;
        let last_request = zeta_project.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Wait out the remainder of the throttle window if the previous
            // refresh targeted the same entity.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            this.update(cx, |this, cx| {
                this.get_or_init_zeta_project(&project, cx)
                    .last_prediction_refresh = Some((throttle_entity, Instant::now()));
            })
            .ok();

            let edit_prediction_id = do_refresh(this.clone(), cx).await.log_err().flatten();

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let zeta_project = this.get_or_init_zeta_project(&project, cx);
                let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            this.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_zeta_project(&project, cx)
                    .pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            edit_prediction_id
        });

        // At most two requests may be pending; when two already are, the newest
        // one is replaced by this request and cancelled.
        if zeta_project.pending_predictions.len() <= 1 {
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if zeta_project.pending_predictions.len() == 2 {
            let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            self.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1180
1181    pub fn request_prediction(
1182        &mut self,
1183        project: &Entity<Project>,
1184        active_buffer: &Entity<Buffer>,
1185        position: language::Anchor,
1186        cx: &mut Context<Self>,
1187    ) -> Task<Result<Option<EditPrediction>>> {
1188        self.request_prediction_internal(
1189            project.clone(),
1190            active_buffer.clone(),
1191            position,
1192            cx.has_flag::<Zeta2FeatureFlag>(),
1193            cx,
1194        )
1195    }
1196
    /// Requests an edit prediction for `position` in `active_buffer`, dispatching
    /// to the backend selected by `edit_prediction_model`.
    ///
    /// If the request yields no edits and `allow_jump` is true, a follow-up
    /// request is attempted at the nearest diagnostic location outside the range
    /// already covered, with `allow_jump` cleared so at most one jump happens.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        // Rows above/below the cursor considered "nearby" when gathering diagnostics.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            // Treat a prediction with no edits the same as no prediction at all.
            let prediction = task
                .await?
                .filter(|prediction| !prediction.edits.is_empty());

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                // Only jump when there is recent edit activity to base the prediction on.
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    // Recurse once with `allow_jump` disabled.
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1286
    /// Picks a location to "jump" to for a follow-up prediction.
    ///
    /// First tries the diagnostic in `active_buffer` closest to the cursor that
    /// falls outside `active_buffer_diagnostic_search_range` (diagnostics inside
    /// that range were already covered by the previous request). Failing that,
    /// opens the other buffer with diagnostics whose path shares the most
    /// leading components with the active buffer's path, and picks the entry
    /// with the minimum severity value (most severe, assuming LSP-style
    /// ordering — TODO confirm). Returns `None` if no candidate exists.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            // Closest by row distance from the cursor.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // Fall back to diagnostics in other buffers of the project.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1365
1366    fn request_prediction_with_zeta2(
1367        &mut self,
1368        project: &Entity<Project>,
1369        active_buffer: &Entity<Buffer>,
1370        active_snapshot: BufferSnapshot,
1371        position: language::Anchor,
1372        events: Vec<Arc<Event>>,
1373        cx: &mut Context<Self>,
1374    ) -> Task<Result<Option<EditPrediction>>> {
1375        let project_state = self.projects.get(&project.entity_id());
1376
1377        let index_state = project_state.and_then(|state| {
1378            state
1379                .syntax_index
1380                .as_ref()
1381                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
1382        });
1383        let options = self.options.clone();
1384        let buffer_snapshotted_at = Instant::now();
1385        let Some(excerpt_path) = active_snapshot
1386            .file()
1387            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
1388        else {
1389            return Task::ready(Err(anyhow!("No file path for excerpt")));
1390        };
1391        let client = self.client.clone();
1392        let llm_token = self.llm_token.clone();
1393        let app_version = AppVersion::global(cx);
1394        let worktree_snapshots = project
1395            .read(cx)
1396            .worktrees(cx)
1397            .map(|worktree| worktree.read(cx).snapshot())
1398            .collect::<Vec<_>>();
1399        let debug_tx = self.debug_tx.clone();
1400
1401        let diagnostics = active_snapshot.diagnostic_sets().clone();
1402
1403        let file = active_buffer.read(cx).file();
1404        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
1405            let mut path = f.worktree.read(cx).absolutize(&f.path);
1406            if path.pop() { Some(path) } else { None }
1407        });
1408
1409        // TODO data collection
1410        let can_collect_data = file
1411            .as_ref()
1412            .map_or(false, |file| self.can_collect_file(project, file, cx));
1413
1414        let empty_context_files = HashMap::default();
1415        let context_files = project_state
1416            .and_then(|project_state| project_state.context.as_ref())
1417            .unwrap_or(&empty_context_files);
1418
1419        #[cfg(feature = "eval-support")]
1420        let parsed_fut = futures::future::join_all(
1421            context_files
1422                .keys()
1423                .map(|buffer| buffer.read(cx).parsing_idle()),
1424        );
1425
1426        let mut included_files = context_files
1427            .iter()
1428            .filter_map(|(buffer_entity, ranges)| {
1429                let buffer = buffer_entity.read(cx);
1430                Some((
1431                    buffer_entity.clone(),
1432                    buffer.snapshot(),
1433                    buffer.file()?.full_path(cx).into(),
1434                    ranges.clone(),
1435                ))
1436            })
1437            .collect::<Vec<_>>();
1438
1439        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
1440            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
1441        });
1442
1443        #[cfg(feature = "eval-support")]
1444        let eval_cache = self.eval_cache.clone();
1445
1446        let request_task = cx.background_spawn({
1447            let active_buffer = active_buffer.clone();
1448            async move {
1449                #[cfg(feature = "eval-support")]
1450                parsed_fut.await;
1451
1452                let index_state = if let Some(index_state) = index_state {
1453                    Some(index_state.lock_owned().await)
1454                } else {
1455                    None
1456                };
1457
1458                let cursor_offset = position.to_offset(&active_snapshot);
1459                let cursor_point = cursor_offset.to_point(&active_snapshot);
1460
1461                let before_retrieval = Instant::now();
1462
1463                let (diagnostic_groups, diagnostic_groups_truncated) =
1464                    Self::gather_nearby_diagnostics(
1465                        cursor_offset,
1466                        &diagnostics,
1467                        &active_snapshot,
1468                        options.max_diagnostic_bytes,
1469                    );
1470
1471                let cloud_request = match options.context {
1472                    ContextMode::Agentic(context_options) => {
1473                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
1474                            cursor_point,
1475                            &active_snapshot,
1476                            &context_options.excerpt,
1477                            index_state.as_deref(),
1478                        ) else {
1479                            return Ok((None, None));
1480                        };
1481
1482                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
1483                            ..active_snapshot.anchor_before(excerpt.range.end);
1484
1485                        if let Some(buffer_ix) =
1486                            included_files.iter().position(|(_, snapshot, _, _)| {
1487                                snapshot.remote_id() == active_snapshot.remote_id()
1488                            })
1489                        {
1490                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
1491                            ranges.push(excerpt_anchor_range);
1492                            retrieval_search::merge_anchor_ranges(ranges, buffer);
1493                            let last_ix = included_files.len() - 1;
1494                            included_files.swap(buffer_ix, last_ix);
1495                        } else {
1496                            included_files.push((
1497                                active_buffer.clone(),
1498                                active_snapshot.clone(),
1499                                excerpt_path.clone(),
1500                                vec![excerpt_anchor_range],
1501                            ));
1502                        }
1503
1504                        let included_files = included_files
1505                            .iter()
1506                            .map(|(_, snapshot, path, ranges)| {
1507                                let ranges = ranges
1508                                    .iter()
1509                                    .map(|range| {
1510                                        let point_range = range.to_point(&snapshot);
1511                                        Line(point_range.start.row)..Line(point_range.end.row)
1512                                    })
1513                                    .collect::<Vec<_>>();
1514                                let excerpts = assemble_excerpts(&snapshot, ranges);
1515                                predict_edits_v3::IncludedFile {
1516                                    path: path.clone(),
1517                                    max_row: Line(snapshot.max_point().row),
1518                                    excerpts,
1519                                }
1520                            })
1521                            .collect::<Vec<_>>();
1522
1523                        predict_edits_v3::PredictEditsRequest {
1524                            excerpt_path,
1525                            excerpt: String::new(),
1526                            excerpt_line_range: Line(0)..Line(0),
1527                            excerpt_range: 0..0,
1528                            cursor_point: predict_edits_v3::Point {
1529                                line: predict_edits_v3::Line(cursor_point.row),
1530                                column: cursor_point.column,
1531                            },
1532                            included_files,
1533                            referenced_declarations: vec![],
1534                            events,
1535                            can_collect_data,
1536                            diagnostic_groups,
1537                            diagnostic_groups_truncated,
1538                            debug_info: debug_tx.is_some(),
1539                            prompt_max_bytes: Some(options.max_prompt_bytes),
1540                            prompt_format: options.prompt_format,
1541                            // TODO [zeta2]
1542                            signatures: vec![],
1543                            excerpt_parent: None,
1544                            git_info: None,
1545                        }
1546                    }
1547                    ContextMode::Syntax(context_options) => {
1548                        let Some(context) = EditPredictionContext::gather_context(
1549                            cursor_point,
1550                            &active_snapshot,
1551                            parent_abs_path.as_deref(),
1552                            &context_options,
1553                            index_state.as_deref(),
1554                        ) else {
1555                            return Ok((None, None));
1556                        };
1557
1558                        make_syntax_context_cloud_request(
1559                            excerpt_path,
1560                            context,
1561                            events,
1562                            can_collect_data,
1563                            diagnostic_groups,
1564                            diagnostic_groups_truncated,
1565                            None,
1566                            debug_tx.is_some(),
1567                            &worktree_snapshots,
1568                            index_state.as_deref(),
1569                            Some(options.max_prompt_bytes),
1570                            options.prompt_format,
1571                        )
1572                    }
1573                };
1574
1575                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
1576
1577                let inputs = EditPredictionInputs {
1578                    included_files: cloud_request.included_files,
1579                    events: cloud_request.events,
1580                    cursor_point: cloud_request.cursor_point,
1581                    cursor_path: cloud_request.excerpt_path,
1582                };
1583
1584                let retrieval_time = Instant::now() - before_retrieval;
1585
1586                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
1587                    let (response_tx, response_rx) = oneshot::channel();
1588
1589                    debug_tx
1590                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
1591                            ZetaEditPredictionDebugInfo {
1592                                inputs: inputs.clone(),
1593                                retrieval_time,
1594                                buffer: active_buffer.downgrade(),
1595                                local_prompt: match prompt_result.as_ref() {
1596                                    Ok((prompt, _)) => Ok(prompt.clone()),
1597                                    Err(err) => Err(err.to_string()),
1598                                },
1599                                position,
1600                                response_rx,
1601                            },
1602                        ))
1603                        .ok();
1604                    Some(response_tx)
1605                } else {
1606                    None
1607                };
1608
1609                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
1610                    if let Some(debug_response_tx) = debug_response_tx {
1611                        debug_response_tx
1612                            .send((Err("Request skipped".to_string()), Duration::ZERO))
1613                            .ok();
1614                    }
1615                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
1616                }
1617
1618                let (prompt, _) = prompt_result?;
1619                let generation_params =
1620                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
1621                let request = open_ai::Request {
1622                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1623                    messages: vec![open_ai::RequestMessage::User {
1624                        content: open_ai::MessageContent::Plain(prompt),
1625                    }],
1626                    stream: false,
1627                    max_completion_tokens: None,
1628                    stop: generation_params.stop.unwrap_or_default(),
1629                    temperature: generation_params.temperature.unwrap_or(0.7),
1630                    tool_choice: None,
1631                    parallel_tool_calls: None,
1632                    tools: vec![],
1633                    prompt_cache_key: None,
1634                    reasoning_effort: None,
1635                };
1636
1637                log::trace!("Sending edit prediction request");
1638
1639                let before_request = Instant::now();
1640                let response = Self::send_raw_llm_request(
1641                    request,
1642                    client,
1643                    llm_token,
1644                    app_version,
1645                    #[cfg(feature = "eval-support")]
1646                    eval_cache,
1647                    #[cfg(feature = "eval-support")]
1648                    EvalCacheEntryKind::Prediction,
1649                )
1650                .await;
1651                let received_response_at = Instant::now();
1652                let request_time = received_response_at - before_request;
1653
1654                log::trace!("Got edit prediction response");
1655
1656                if let Some(debug_response_tx) = debug_response_tx {
1657                    debug_response_tx
1658                        .send((
1659                            response
1660                                .as_ref()
1661                                .map_err(|err| err.to_string())
1662                                .map(|response| response.0.clone()),
1663                            request_time,
1664                        ))
1665                        .ok();
1666                }
1667
1668                let (res, usage) = response?;
1669                let request_id = EditPredictionId(res.id.clone().into());
1670                let Some(mut output_text) = text_from_response(res) else {
1671                    return Ok((None, usage));
1672                };
1673
1674                if output_text.contains(CURSOR_MARKER) {
1675                    log::trace!("Stripping out {CURSOR_MARKER} from response");
1676                    output_text = output_text.replace(CURSOR_MARKER, "");
1677                }
1678
1679                let get_buffer_from_context = |path: &Path| {
1680                    included_files
1681                        .iter()
1682                        .find_map(|(_, buffer, probe_path, ranges)| {
1683                            if probe_path.as_ref() == path {
1684                                Some((buffer, ranges.as_slice()))
1685                            } else {
1686                                None
1687                            }
1688                        })
1689                };
1690
1691                let (edited_buffer_snapshot, edits) = match options.prompt_format {
1692                    PromptFormat::NumLinesUniDiff => {
1693                        // TODO: Implement parsing of multi-file diffs
1694                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1695                    }
1696                    PromptFormat::Minimal
1697                    | PromptFormat::MinimalQwen
1698                    | PromptFormat::SeedCoder1120 => {
1699                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
1700                            let edits = vec![];
1701                            (&active_snapshot, edits)
1702                        } else {
1703                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1704                        }
1705                    }
1706                    PromptFormat::OldTextNewText => {
1707                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
1708                            .await?
1709                    }
1710                    _ => {
1711                        bail!("unsupported prompt format {}", options.prompt_format)
1712                    }
1713                };
1714
1715                let edited_buffer = included_files
1716                    .iter()
1717                    .find_map(|(buffer, snapshot, _, _)| {
1718                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1719                            Some(buffer.clone())
1720                        } else {
1721                            None
1722                        }
1723                    })
1724                    .context("Failed to find buffer in included_buffers")?;
1725
1726                anyhow::Ok((
1727                    Some((
1728                        request_id,
1729                        inputs,
1730                        edited_buffer,
1731                        edited_buffer_snapshot.clone(),
1732                        edits,
1733                        received_response_at,
1734                    )),
1735                    usage,
1736                ))
1737            }
1738        });
1739
1740        cx.spawn({
1741            async move |this, cx| {
1742                let Some((
1743                    id,
1744                    inputs,
1745                    edited_buffer,
1746                    edited_buffer_snapshot,
1747                    edits,
1748                    received_response_at,
1749                )) = Self::handle_api_response(&this, request_task.await, cx)?
1750                else {
1751                    return Ok(None);
1752                };
1753
1754                // TODO telemetry: duration, etc
1755                Ok(EditPrediction::new(
1756                    id,
1757                    &edited_buffer,
1758                    &edited_buffer_snapshot,
1759                    edits.into(),
1760                    buffer_snapshotted_at,
1761                    received_response_at,
1762                    inputs,
1763                    cx,
1764                )
1765                .await)
1766            }
1767        })
1768    }
1769
    /// Sends an OpenAI-style `request` to the edit-prediction endpoint and
    /// returns the parsed response plus any usage info from the response headers.
    ///
    /// The endpoint is `PREDICT_EDITS_URL` when that override is set, otherwise
    /// the client's `/predict_edits/raw` LLM URL. With the `eval-support`
    /// feature, responses are memoized in `eval_cache`, keyed by a hash of the
    /// URL and the serialized request; cache hits return `None` for usage.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        // Compute the cache key and early-return on a hit. The key hashes the
        // URL and the pretty-printed request so any request change misses.
        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                // Cached responses carry no usage information.
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Record the fresh response so subsequent identical requests hit the cache.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1827
1828    fn handle_api_response<T>(
1829        this: &WeakEntity<Self>,
1830        response: Result<(T, Option<EditPredictionUsage>)>,
1831        cx: &mut gpui::AsyncApp,
1832    ) -> Result<T> {
1833        match response {
1834            Ok((data, usage)) => {
1835                if let Some(usage) = usage {
1836                    this.update(cx, |this, cx| {
1837                        this.user_store.update(cx, |user_store, cx| {
1838                            user_store.update_edit_prediction_usage(usage, cx);
1839                        });
1840                    })
1841                    .ok();
1842                }
1843                Ok(data)
1844            }
1845            Err(err) => {
1846                if err.is::<ZedUpdateRequiredError>() {
1847                    cx.update(|cx| {
1848                        this.update(cx, |this, _cx| {
1849                            this.update_required = true;
1850                        })
1851                        .ok();
1852
1853                        let error_message: SharedString = err.to_string().into();
1854                        show_app_notification(
1855                            NotificationId::unique::<ZedUpdateRequiredError>(),
1856                            cx,
1857                            move |cx| {
1858                                cx.new(|cx| {
1859                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1860                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1861                                })
1862                            },
1863                        );
1864                    })
1865                    .ok();
1866                }
1867                Err(err)
1868            }
1869        }
1870    }
1871
1872    async fn send_api_request<Res>(
1873        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1874        client: Arc<Client>,
1875        llm_token: LlmApiToken,
1876        app_version: Version,
1877    ) -> Result<(Res, Option<EditPredictionUsage>)>
1878    where
1879        Res: DeserializeOwned,
1880    {
1881        let http_client = client.http_client();
1882        let mut token = llm_token.acquire(&client).await?;
1883        let mut did_retry = false;
1884
1885        loop {
1886            let request_builder = http_client::Request::builder().method(Method::POST);
1887
1888            let request = build(
1889                request_builder
1890                    .header("Content-Type", "application/json")
1891                    .header("Authorization", format!("Bearer {}", token))
1892                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
1893            )?;
1894
1895            let mut response = http_client.send(request).await?;
1896
1897            if let Some(minimum_required_version) = response
1898                .headers()
1899                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1900                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1901            {
1902                anyhow::ensure!(
1903                    app_version >= minimum_required_version,
1904                    ZedUpdateRequiredError {
1905                        minimum_version: minimum_required_version
1906                    }
1907                );
1908            }
1909
1910            if response.status().is_success() {
1911                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1912
1913                let mut body = Vec::new();
1914                response.body_mut().read_to_end(&mut body).await?;
1915                return Ok((serde_json::from_slice(&body)?, usage));
1916            } else if !did_retry
1917                && response
1918                    .headers()
1919                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1920                    .is_some()
1921            {
1922                did_retry = true;
1923                token = llm_token.refresh(&client).await?;
1924            } else {
1925                let mut body = String::new();
1926                response.body_mut().read_to_string(&mut body).await?;
1927                anyhow::bail!(
1928                    "Request failed with status: {:?}\nBody: {}",
1929                    response.status(),
1930                    body
1931                );
1932            }
1933        }
1934    }
1935
    /// Gap since the last context refresh after which the user counts as idle;
    /// editing after this gap triggers an immediate refresh (see
    /// `refresh_context_if_needed`).
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce delay applied before refreshing context while the user is
    /// actively editing (see `refresh_context_if_needed`).
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1938
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
1941    fn refresh_context_if_needed(
1942        &mut self,
1943        project: &Entity<Project>,
1944        buffer: &Entity<language::Buffer>,
1945        cursor_position: language::Anchor,
1946        cx: &mut Context<Self>,
1947    ) {
1948        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
1949            return;
1950        }
1951
1952        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1953            return;
1954        };
1955
1956        let now = Instant::now();
1957        let was_idle = zeta_project
1958            .refresh_context_timestamp
1959            .map_or(true, |timestamp| {
1960                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1961            });
1962        zeta_project.refresh_context_timestamp = Some(now);
1963        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1964            let buffer = buffer.clone();
1965            let project = project.clone();
1966            async move |this, cx| {
1967                if was_idle {
1968                    log::debug!("refetching edit prediction context after idle");
1969                } else {
1970                    cx.background_executor()
1971                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1972                        .await;
1973                    log::debug!("refetching edit prediction context after pause");
1974                }
1975                this.update(cx, |this, cx| {
1976                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
1977
1978                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
1979                        zeta_project.refresh_context_task = Some(task.log_err());
1980                    };
1981                })
1982                .ok()
1983            }
1984        }));
1985    }
1986
1987    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1988    // and avoid spawning more than one concurrent task.
    /// Refreshes the project's related-excerpt context for `cursor_position`.
    ///
    /// Builds a retrieval prompt from the excerpt around the cursor, asks the
    /// context-retrieval model for search queries via a tool call, runs those
    /// searches locally, and stores the resulting excerpts on the project
    /// state. No-op unless the project is known and context mode is agentic.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        // Select the excerpt surrounding the cursor; bail quietly if none fits.
        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        // Untitled buffers get a placeholder path.
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project.events(cx),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // JSON schema + description for the search tool, derived once from
        // `SearchToolInput` and reused across calls.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect queries from every call to the expected search tool;
            // calls to unknown tools are logged and skipped.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Store the excerpts (or surface the search error) and clear the
            // in-flight task handle.
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
2198
2199    pub fn set_context(
2200        &mut self,
2201        project: Entity<Project>,
2202        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2203    ) {
2204        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2205            zeta_project.context = Some(context);
2206        }
2207    }
2208
2209    fn gather_nearby_diagnostics(
2210        cursor_offset: usize,
2211        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
2212        snapshot: &BufferSnapshot,
2213        max_diagnostics_bytes: usize,
2214    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
2215        // TODO: Could make this more efficient
2216        let mut diagnostic_groups = Vec::new();
2217        for (language_server_id, diagnostics) in diagnostic_sets {
2218            let mut groups = Vec::new();
2219            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
2220            diagnostic_groups.extend(
2221                groups
2222                    .into_iter()
2223                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
2224            );
2225        }
2226
2227        // sort by proximity to cursor
2228        diagnostic_groups.sort_by_key(|group| {
2229            let range = &group.entries[group.primary_ix].range;
2230            if range.start >= cursor_offset {
2231                range.start - cursor_offset
2232            } else if cursor_offset >= range.end {
2233                cursor_offset - range.end
2234            } else {
2235                (cursor_offset - range.start).min(range.end - cursor_offset)
2236            }
2237        });
2238
2239        let mut results = Vec::new();
2240        let mut diagnostic_groups_truncated = false;
2241        let mut diagnostics_byte_count = 0;
2242        for group in diagnostic_groups {
2243            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
2244            diagnostics_byte_count += raw_value.get().len();
2245            if diagnostics_byte_count > max_diagnostics_bytes {
2246                diagnostic_groups_truncated = true;
2247                break;
2248            }
2249            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
2250        }
2251
2252        (results, diagnostic_groups_truncated)
2253    }
2254
2255    // TODO: Dedupe with similar code in request_prediction?
    /// Builds a syntax-context cloud prediction request for the zeta CLI.
    ///
    /// Gathers edit-prediction context for `position` in `buffer` using the
    /// syntax-index path. Panics in agentic context mode (not yet supported
    /// here), and errors when the buffer has no file path or no excerpt can be
    /// selected. Events, diagnostics, and data-collection flags are not
    /// populated (see TODO below).
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, when the file is
        // part of a worktree.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            // Hold the syntax index lock for the duration of context gathering.
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
2332
2333    pub fn wait_for_initial_indexing(
2334        &mut self,
2335        project: &Entity<Project>,
2336        cx: &mut Context<Self>,
2337    ) -> Task<Result<()>> {
2338        let zeta_project = self.get_or_init_zeta_project(project, cx);
2339        if let Some(syntax_index) = &zeta_project.syntax_index {
2340            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2341        } else {
2342            Task::ready(Ok(()))
2343        }
2344    }
2345
2346    fn is_file_open_source(
2347        &self,
2348        project: &Entity<Project>,
2349        file: &Arc<dyn File>,
2350        cx: &App,
2351    ) -> bool {
2352        if !file.is_local() || file.is_private() {
2353            return false;
2354        }
2355        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2356            return false;
2357        };
2358        zeta_project
2359            .license_detection_watchers
2360            .get(&file.worktree_id(cx))
2361            .as_ref()
2362            .is_some_and(|watcher| watcher.is_project_open_source())
2363    }
2364
2365    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2366        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2367    }
2368
2369    fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2370        if !self.data_collection_choice.is_enabled() {
2371            return false;
2372        }
2373        events.iter().all(|event| {
2374            matches!(
2375                event.as_ref(),
2376                Event::BufferChange {
2377                    in_open_source_repo: true,
2378                    ..
2379                }
2380            )
2381        })
2382    }
2383
2384    fn load_data_collection_choice() -> DataCollectionChoice {
2385        let choice = KEY_VALUE_STORE
2386            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2387            .log_err()
2388            .flatten();
2389
2390        match choice.as_deref() {
2391            Some("true") => DataCollectionChoice::Enabled,
2392            Some("false") => DataCollectionChoice::Disabled,
2393            Some(_) => {
2394                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2395                DataCollectionChoice::NotAnswered
2396            }
2397            None => DataCollectionChoice::NotAnswered,
2398        }
2399    }
2400
    /// Iterator over predictions that have been shown to the user, in the
    /// order they were recorded.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2404
    /// Number of predictions that have been shown to the user.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2408
    /// Whether the user has already submitted a rating for the prediction
    /// with the given `id`.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2412
    /// Records a user rating for `prediction` and reports it via telemetry.
    ///
    /// The prediction's id is remembered so the UI can mark it as rated, and
    /// queued telemetry events are flushed in a detached background task.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        // The rated edit is reported as a unified diff alongside its inputs.
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2431}
2432
2433pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2434    let choice = res.choices.pop()?;
2435    let output_text = match choice.message {
2436        open_ai::RequestMessage::Assistant {
2437            content: Some(open_ai::MessageContent::Plain(content)),
2438            ..
2439        } => content,
2440        open_ai::RequestMessage::Assistant {
2441            content: Some(open_ai::MessageContent::Multipart(mut content)),
2442            ..
2443        } => {
2444            if content.is_empty() {
2445                log::error!("No output from Baseten completion response");
2446                return None;
2447            }
2448
2449            match content.remove(0) {
2450                open_ai::MessagePart::Text { text } => text,
2451                open_ai::MessagePart::Image { .. } => {
2452                    log::error!("Expected text, got an image");
2453                    return None;
2454                }
2455            }
2456        }
2457        _ => {
2458            log::error!("Invalid response message: {:?}", choice.message);
2459            return None;
2460        }
2461    };
2462    Some(output_text)
2463}
2464
/// Error returned when the server requires a newer Zed build than the one
/// currently running in order to serve edit predictions.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    // Minimum Zed version the server will accept.
    minimum_version: Version,
}
2472
2473fn make_syntax_context_cloud_request(
2474    excerpt_path: Arc<Path>,
2475    context: EditPredictionContext,
2476    events: Vec<Arc<predict_edits_v3::Event>>,
2477    can_collect_data: bool,
2478    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2479    diagnostic_groups_truncated: bool,
2480    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2481    debug_info: bool,
2482    worktrees: &Vec<worktree::Snapshot>,
2483    index_state: Option<&SyntaxIndexState>,
2484    prompt_max_bytes: Option<usize>,
2485    prompt_format: PromptFormat,
2486) -> predict_edits_v3::PredictEditsRequest {
2487    let mut signatures = Vec::new();
2488    let mut declaration_to_signature_index = HashMap::default();
2489    let mut referenced_declarations = Vec::new();
2490
2491    for snippet in context.declarations {
2492        let project_entry_id = snippet.declaration.project_entry_id();
2493        let Some(path) = worktrees.iter().find_map(|worktree| {
2494            worktree.entry_for_id(project_entry_id).map(|entry| {
2495                let mut full_path = RelPathBuf::new();
2496                full_path.push(worktree.root_name());
2497                full_path.push(&entry.path);
2498                full_path
2499            })
2500        }) else {
2501            continue;
2502        };
2503
2504        let parent_index = index_state.and_then(|index_state| {
2505            snippet.declaration.parent().and_then(|parent| {
2506                add_signature(
2507                    parent,
2508                    &mut declaration_to_signature_index,
2509                    &mut signatures,
2510                    index_state,
2511                )
2512            })
2513        });
2514
2515        let (text, text_is_truncated) = snippet.declaration.item_text();
2516        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2517            path: path.as_std_path().into(),
2518            text: text.into(),
2519            range: snippet.declaration.item_line_range(),
2520            text_is_truncated,
2521            signature_range: snippet.declaration.signature_range_in_item_text(),
2522            parent_index,
2523            signature_score: snippet.score(DeclarationStyle::Signature),
2524            declaration_score: snippet.score(DeclarationStyle::Declaration),
2525            score_components: snippet.components,
2526        });
2527    }
2528
2529    let excerpt_parent = index_state.and_then(|index_state| {
2530        context
2531            .excerpt
2532            .parent_declarations
2533            .last()
2534            .and_then(|(parent, _)| {
2535                add_signature(
2536                    *parent,
2537                    &mut declaration_to_signature_index,
2538                    &mut signatures,
2539                    index_state,
2540                )
2541            })
2542    });
2543
2544    predict_edits_v3::PredictEditsRequest {
2545        excerpt_path,
2546        excerpt: context.excerpt_text.body,
2547        excerpt_line_range: context.excerpt.line_range,
2548        excerpt_range: context.excerpt.range,
2549        cursor_point: predict_edits_v3::Point {
2550            line: predict_edits_v3::Line(context.cursor_point.row),
2551            column: context.cursor_point.column,
2552        },
2553        referenced_declarations,
2554        included_files: vec![],
2555        signatures,
2556        excerpt_parent,
2557        events,
2558        can_collect_data,
2559        diagnostic_groups,
2560        diagnostic_groups_truncated,
2561        git_info,
2562        debug_info,
2563        prompt_max_bytes,
2564        prompt_format,
2565    }
2566}
2567
2568fn add_signature(
2569    declaration_id: DeclarationId,
2570    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2571    signatures: &mut Vec<Signature>,
2572    index: &SyntaxIndexState,
2573) -> Option<usize> {
2574    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2575        return Some(*signature_index);
2576    }
2577    let Some(parent_declaration) = index.declaration(declaration_id) else {
2578        log::error!("bug: missing parent declaration");
2579        return None;
2580    };
2581    let parent_index = parent_declaration.parent().and_then(|parent| {
2582        add_signature(parent, declaration_to_signature_index, signatures, index)
2583    });
2584    let (text, text_is_truncated) = parent_declaration.signature_text();
2585    let signature_index = signatures.len();
2586    signatures.push(Signature {
2587        text: text.into(),
2588        text_is_truncated,
2589        parent_index,
2590        range: parent_declaration.signature_line_range(),
2591    });
2592    declaration_to_signature_index.insert(declaration_id, signature_index);
2593    Some(signature_index)
2594}
2595
/// Key identifying an entry in the eval cache: the entry kind plus a `u64`
/// discriminator (presumably a hash of the request input — confirm at call
/// sites).
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2598
/// Kind of cached eval artifact; forms the first component of
/// [`EvalCacheKey`].
// Also derive `Eq` and `Hash`: `PartialEq` without `Eq` trips
// clippy::derive_partial_eq_without_eq, and as part of a cache-key tuple
// this type should be usable as a hash-map key.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2606
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the lowercase label used to identify cache entries.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(label)
    }
}
2617
/// Read/write interface for caching intermediate eval results, keyed by
/// [`EvalCacheKey`].
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached value for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` under `key`; `input` is supplied for implementations
    /// that record it alongside the value — see implementors.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2623
/// The user's choice about contributing edit-prediction data.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    // No explicit choice yet; treated as disabled by `is_enabled`.
    NotAnswered,
    // User opted in to data collection.
    Enabled,
    // User opted out of data collection.
    Disabled,
}
2630
2631impl DataCollectionChoice {
2632    pub fn is_enabled(self) -> bool {
2633        match self {
2634            Self::Enabled => true,
2635            Self::NotAnswered | Self::Disabled => false,
2636        }
2637    }
2638
2639    pub fn is_answered(self) -> bool {
2640        match self {
2641            Self::Enabled | Self::Disabled => true,
2642            Self::NotAnswered => false,
2643        }
2644    }
2645
2646    #[must_use]
2647    pub fn toggle(&self) -> DataCollectionChoice {
2648        match self {
2649            Self::Enabled => Self::Disabled,
2650            Self::Disabled => Self::Enabled,
2651            Self::NotAnswered => Self::Enabled,
2652        }
2653    }
2654}
2655
2656impl From<bool> for DataCollectionChoice {
2657    fn from(value: bool) -> Self {
2658        match value {
2659            true => DataCollectionChoice::Enabled,
2660            false => DataCollectionChoice::Disabled,
2661        }
2662    }
2663}
2664
// Marker type tracking whether the edit-prediction upsell was dismissed.
struct ZedPredictUpsell;
2666
2667impl Dismissable for ZedPredictUpsell {
2668    const KEY: &'static str = "dismissed-edit-predict-upsell";
2669
2670    fn dismissed() -> bool {
2671        // To make this backwards compatible with older versions of Zed, we
2672        // check if the user has seen the previous Edit Prediction Onboarding
2673        // before, by checking the data collection choice which was written to
2674        // the database once the user clicked on "Accept and Enable"
2675        if KEY_VALUE_STORE
2676            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2677            .log_err()
2678            .is_some_and(|s| s.is_some())
2679        {
2680            return true;
2681        }
2682
2683        KEY_VALUE_STORE
2684            .read_kvp(Self::KEY)
2685            .log_err()
2686            .is_some_and(|s| s.is_some())
2687    }
2688}
2689
/// Whether the edit-prediction upsell modal should be shown, i.e. it has not
/// been dismissed (directly or via the legacy onboarding flow).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2693
/// Registers Zeta's workspace actions and sets up command-palette gating.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // The rating modal is only reachable behind the feature flag.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Writes `edit_prediction_provider = None` back to the settings file;
        // presumably this causes onboarding to be offered again — confirm.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2729
/// Hides or shows Zeta's actions in the command palette based on the
/// "disable AI" setting and the rate-completions feature flag.
fn feature_gate_predict_edits_actions(cx: &mut App) {
    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
    let zeta_all_action_types = [
        TypeId::of::<RateCompletions>(),
        TypeId::of::<ResetOnboarding>(),
        zed_actions::OpenZedPredictOnboarding.type_id(),
        TypeId::of::<ClearHistory>(),
        TypeId::of::<ThumbsUpActivePrediction>(),
        TypeId::of::<ThumbsDownActivePrediction>(),
        TypeId::of::<NextEdit>(),
        TypeId::of::<PreviousEdit>(),
    ];

    // Hidden by default; the observers below re-show the rate-completions
    // actions when the flag allows it.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_action_types(&rate_completion_action_types);
        filter.hide_action_types(&reset_onboarding_action_types);
        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
    });

    cx.observe_global::<SettingsStore>(move |cx| {
        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            if is_ai_disabled {
                filter.hide_action_types(&zeta_all_action_types);
            } else if has_feature_flag {
                // NOTE(review): after AI is re-enabled, only the
                // rate-completions actions are re-shown here; the rest of
                // `zeta_all_action_types` hidden in the branch above stays
                // hidden — confirm this is intentional.
                filter.show_action_types(&rate_completion_action_types);
            } else {
                filter.hide_action_types(&rate_completion_action_types);
            }
        });
    })
    .detach();

    // Also react to flag changes directly, not just settings changes.
    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
        if !DisableAiSettings::get_global(cx).disable_ai {
            if is_enabled {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.show_action_types(&rate_completion_action_types);
                });
            } else {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.hide_action_types(&rate_completion_action_types);
                });
            }
        }
    })
    .detach();
}
2781
2782#[cfg(test)]
2783mod tests {
2784    use std::{path::Path, sync::Arc};
2785
2786    use client::UserStore;
2787    use clock::FakeSystemClock;
2788    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2789    use futures::{
2790        AsyncReadExt, StreamExt,
2791        channel::{mpsc, oneshot},
2792    };
2793    use gpui::{
2794        Entity, TestAppContext,
2795        http_client::{FakeHttpClient, Response},
2796        prelude::*,
2797    };
2798    use indoc::indoc;
2799    use language::OffsetRangeExt as _;
2800    use open_ai::Usage;
2801    use pretty_assertions::{assert_eq, assert_matches};
2802    use project::{FakeFs, Project};
2803    use serde_json::json;
2804    use settings::SettingsStore;
2805    use util::path;
2806    use uuid::Uuid;
2807
2808    use crate::{BufferEditPrediction, Zeta};
2809
    /// End-to-end check of `current_prediction_for_buffer` state transitions:
    /// a prediction in the current buffer surfaces as `Local`, a prediction
    /// targeting another file surfaces as `Jump` from the original buffer and
    /// as `Local` from the target buffer.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // Respond with a diff that edits the buffer the cursor is in.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // The model answers the context request with a search tool call
        // targeting the other file.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.discard_current_prediction(&project, cx);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // Seen from buffer1, the prediction points at 2.txt, so it's a Jump.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Seen from the target buffer itself, the same prediction is Local.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
2956
    /// A single prediction round-trip: the model's unified diff is converted
    /// into buffer edits anchored at the right position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The diff replaces "How" with "How are you?"; the resulting edit
        // inserts " are you?" at row 1, column 3.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3020
    /// Verifies that buffer edits made after `register_buffer` are reported
    /// to the model as a unified diff in the request prompt.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Register first so the subsequent edit is captured as an event.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The recorded edit must appear in the prompt as a diff hunk.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3094
3095    // Skipped until we start including diagnostics in prompt
3096    // #[gpui::test]
3097    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3098    //     let (zeta, mut req_rx) = init_test(cx);
3099    //     let fs = FakeFs::new(cx.executor());
3100    //     fs.insert_tree(
3101    //         "/root",
3102    //         json!({
3103    //             "foo.md": "Hello!\nBye"
3104    //         }),
3105    //     )
3106    //     .await;
3107    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3108
3109    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3110    //     let diagnostic = lsp::Diagnostic {
3111    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3112    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
3113    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3114    //         ..Default::default()
3115    //     };
3116
3117    //     project.update(cx, |project, cx| {
3118    //         project.lsp_store().update(cx, |lsp_store, cx| {
3119    //             // Create some diagnostics
3120    //             lsp_store
3121    //                 .update_diagnostics(
3122    //                     LanguageServerId(0),
3123    //                     lsp::PublishDiagnosticsParams {
3124    //                         uri: path_to_buffer_uri.clone(),
3125    //                         diagnostics: vec![diagnostic],
3126    //                         version: None,
3127    //                     },
3128    //                     None,
3129    //                     language::DiagnosticSourceKind::Pushed,
3130    //                     &[],
3131    //                     cx,
3132    //                 )
3133    //                 .unwrap();
3134    //         });
3135    //     });
3136
3137    //     let buffer = project
3138    //         .update(cx, |project, cx| {
3139    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3140    //             project.open_buffer(path, cx)
3141    //         })
3142    //         .await
3143    //         .unwrap();
3144
3145    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3146    //     let position = snapshot.anchor_before(language::Point::new(0, 0));
3147
3148    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
3149    //         zeta.request_prediction(&project, &buffer, position, cx)
3150    //     });
3151
3152    //     let (request, _respond_tx) = req_rx.next().await.unwrap();
3153
3154    //     assert_eq!(request.diagnostic_groups.len(), 1);
3155    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3156    //         .unwrap();
3157    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3158    //     assert_eq!(
3159    //         value,
3160    //         json!({
3161    //             "entries": [{
3162    //                 "range": {
3163    //                     "start": 8,
3164    //                     "end": 10
3165    //                 },
3166    //                 "diagnostic": {
3167    //                     "source": null,
3168    //                     "code": null,
3169    //                     "code_description": null,
3170    //                     "severity": 1,
3171    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3172    //                     "markdown": null,
3173    //                     "group_id": 0,
3174    //                     "is_primary": true,
3175    //                     "is_disk_based": false,
3176    //                     "is_unnecessary": false,
3177    //                     "source_kind": "Pushed",
3178    //                     "data": null,
3179    //                     "underline": true
3180    //                 }
3181    //             }],
3182    //             "primary_ix": 0
3183    //         })
3184    //     );
3185    // }
3186
3187    fn model_response(text: &str) -> open_ai::Response {
3188        open_ai::Response {
3189            id: Uuid::new_v4().to_string(),
3190            object: "response".into(),
3191            created: 0,
3192            model: "model".into(),
3193            choices: vec![open_ai::Choice {
3194                index: 0,
3195                message: open_ai::RequestMessage::Assistant {
3196                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
3197                    tool_calls: vec![],
3198                },
3199                finish_reason: None,
3200            }],
3201            usage: Usage {
3202                prompt_tokens: 0,
3203                completion_tokens: 0,
3204                total_tokens: 0,
3205            },
3206        }
3207    }
3208
3209    fn prompt_from_request(request: &open_ai::Request) -> &str {
3210        assert_eq!(request.messages.len(), 1);
3211        let open_ai::RequestMessage::User {
3212            content: open_ai::MessageContent::Plain(content),
3213            ..
3214        } = &request.messages[0]
3215        else {
3216            panic!(
3217                "Request does not have single user message of type Plain. {:#?}",
3218                request
3219            );
3220        };
3221        content
3222    }
3223
    /// Sets up a test `Zeta` instance backed by a fake HTTP client.
    ///
    /// Returns the `Zeta` entity plus a channel that yields each request
    /// POSTed to `/predict_edits/raw`, paired with a oneshot sender the
    /// test uses to supply the fake model response for that request.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            // Settings must be installed as a global before anything below
            // (client, language_model, Zeta) reads them.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            // Requests captured by the fake server flow out through req_rx.
            let (req_tx, req_rx) = mpsc::unbounded();

            // Fake server: answers the LLM-token endpoint inline with a
            // static token, and forwards prediction requests to the test,
            // blocking the HTTP response until the test replies via the
            // oneshot sender.
            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                // Best-effort read: errors are ignored and
                                // surface as a deserialization panic below.
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the request to the test and wait for
                                // it to provide the response.
                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            // Pretend we are already signed in so token refresh succeeds.
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
3278}