1use anyhow::{Context as _, Result, anyhow, bail};
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection,
7 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
8 RejectEditPredictionsBody, ZED_VERSION_HEADER_NAME,
9};
10use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
11use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
12use collections::{HashMap, HashSet};
13use command_palette_hooks::CommandPaletteFilter;
14use db::kvp::{Dismissable, KEY_VALUE_STORE};
15use edit_prediction_context::{
16 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
17 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
18 SyntaxIndex, SyntaxIndexState,
19};
20use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
21use futures::channel::{mpsc, oneshot};
22use futures::{AsyncReadExt as _, StreamExt as _};
23use gpui::{
24 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
25 http_client::{self, AsyncBody, Method},
26 prelude::*,
27};
28use language::{
29 Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
30};
31use language::{BufferSnapshot, OffsetRangeExt};
32use language_model::{LlmApiToken, RefreshLlmTokenListener};
33use lsp::DiagnosticSeverity;
34use open_ai::FunctionDefinition;
35use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
36use release_channel::AppVersion;
37use semver::Version;
38use serde::de::DeserializeOwned;
39use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
40use std::any::{Any as _, TypeId};
41use std::collections::{VecDeque, hash_map};
42use telemetry_events::EditPredictionRating;
43use workspace::Workspace;
44
45use std::fmt::Write as _;
46use std::ops::Range;
47use std::path::Path;
48use std::rc::Rc;
49use std::str::FromStr as _;
50use std::sync::{Arc, LazyLock};
51use std::time::{Duration, Instant};
52use std::{env, mem};
53use thiserror::Error;
54use util::rel_path::RelPathBuf;
55use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
57
58pub mod assemble_excerpts;
59mod license_detection;
60mod onboarding_modal;
61mod prediction;
62mod provider;
63mod rate_prediction_modal;
64pub mod retrieval_search;
65mod sweep_ai;
66pub mod udiff;
67mod xml_edits;
68pub mod zeta1;
69
70#[cfg(test)]
71mod zeta_tests;
72
73use crate::assemble_excerpts::assemble_excerpts;
74use crate::license_detection::LicenseDetectionWatcher;
75use crate::onboarding_modal::ZedPredictModal;
76pub use crate::prediction::EditPrediction;
77pub use crate::prediction::EditPredictionId;
78pub use crate::prediction::EditPredictionInputs;
79use crate::rate_prediction_modal::{
80 NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
81 ThumbsUpActivePrediction,
82};
83use crate::zeta1::request_prediction_with_zeta1;
84pub use provider::ZetaEditPredictionProvider;
85
// Actions exposed under the `edit_prediction` namespace for the command
// palette and keymaps.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
97
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Consecutive edits within this many rows of the previous edit are coalesced
/// into a single event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key-value store key under which the user's data collection choice persists.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
102
103pub struct SweepFeatureFlag;
104
105impl FeatureFlag for SweepFeatureFlag {
106 const NAME: &str = "sweep-ai";
107}
/// Default sizing of the excerpt captured around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim to spend half of the excerpt budget before the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context-gathering strategy: agentic retrieval.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Defaults for agentic context retrieval.
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for syntax-index-based context retrieval.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        // Declaration retrieval disabled by default.
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default top-level options for [`Zeta`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
139
/// Set `ZED_ZETA2_OLLAMA` to any non-empty value to route model requests to a
/// local Ollama server instead of the hosted backend.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
142static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
143 env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
144 "qwen3-coder:30b".to_string()
145 } else {
146 "yqvev8r3".to_string()
147 })
148});
/// Model id used for edit prediction itself. `ZED_ZETA2_MODEL` overrides;
/// the sentinel value `zeta2-exp` maps to a fine-tuned model.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
/// Optional override of the prediction endpoint. Defaults to the local Ollama
/// chat-completions URL when `ZED_ZETA2_OLLAMA` is set, otherwise `None`
/// (meaning: use the Zed LLM service).
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
167
/// Feature flag gating the zeta2 edit prediction backend.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Deliberately opt-in, even for staff accounts.
    fn enabled_for_staff() -> bool {
        false
    }
}
177
/// Newtype holding the app-global [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
182
/// App-global coordinator for edit predictions: tracks per-project state and
/// buffered edit events, and talks to the prediction backends.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set when the server reported this Zed version is too old to use.
    update_required: bool,
    /// When present, debug events are streamed to a debug consumer.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    /// Read from `SWEEP_AI_TOKEN`; Sweep is unusable without it.
    sweep_api_token: Option<String>,
    sweep_ai_debug_info: Arc<str>,
    data_collection_choice: DataCollectionChoice,
    /// Rejections batched locally until flushed to the server.
    rejected_predictions: Vec<EditPredictionRejection>,
    /// Signals the background loop (spawned in `new`) to flush rejections.
    reject_predictions_tx: mpsc::UnboundedSender<()>,
    reject_predictions_debounce_task: Option<Task<()>>,
    /// Predictions shown to the user, newest first; capped at 50.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
204
/// Which backend serves edit predictions.
#[derive(Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}
212
/// Tunable options for zeta2 prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    /// Upper bound on the prompt size sent to the model, in bytes.
    pub max_prompt_bytes: usize,
    /// Upper bound on diagnostic text included in the prompt, in bytes.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    // Presumably groups buffer changes closer in time than this into one
    // event — confirm at the use site (not visible in this file section).
    pub buffer_change_grouping_interval: Duration,
}
222
/// How context for a prediction request is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Model-driven retrieval (search-tool based).
    Agentic(AgenticContextOptions),
    /// Retrieval backed by the [`SyntaxIndex`].
    Syntax(EditPredictionContextOptions),
}
228
/// Options specific to agentic context retrieval.
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
233
234impl ContextMode {
235 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
236 match self {
237 ContextMode::Agentic(options) => &options.excerpt,
238 ContextMode::Syntax(options) => &options.excerpt,
239 }
240 }
241}
242
/// Events emitted on the debug channel (see `Zeta::debug_info`), covering the
/// lifecycle of context retrieval and prediction requests.
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}
251
/// Debug payload emitted when context retrieval starts.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// The prompt sent to the search/retrieval model.
    pub search_prompt: String,
}
258
/// Debug payload for context-retrieval progress events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}
264
/// Debug payload emitted when an edit prediction request is issued.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    /// How long the preceding context retrieval took.
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// The locally-assembled prompt, or an error message describing why it
    /// could not be built.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error string) and its latency.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}
274
/// Debug payload emitted when the model produces search queries.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}
281
/// Alias for the wire-format debug info attached to prediction requests.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
283
/// Per-project prediction state owned by [`Zeta`].
struct ZetaProject {
    /// Only built when `ContextMode::Syntax` is active.
    syntax_index: Option<Entity<SyntaxIndex>>,
    /// Finalized edit events, oldest first (see `EVENT_COUNT_MAX`).
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    /// The still-open, coalescing edit event; finalized lazily.
    last_event: Option<LastEvent>,
    /// Recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// At most two in-flight prediction requests.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Entity and time of the last refresh, used for throttling.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Retrieved context ranges per buffer, when retrieval has run.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
301
302impl ZetaProject {
303 pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
304 self.events
305 .iter()
306 .cloned()
307 .chain(
308 self.last_event
309 .as_ref()
310 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
311 )
312 .collect()
313 }
314}
315
/// The prediction currently offered to the user for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request (a buffer edit or a diagnostics update).
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been rendered to the user at least once.
    pub was_shown: bool,
}
322
impl CurrentEditPrediction {
    /// Decides whether `self` (a newly-arrived prediction) should replace
    /// `old_prediction` as the project's current prediction.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // If the new prediction no longer applies to the buffer's current
        // content, keep the old one.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        // Predictions for a different buffer always replace.
        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // If the old prediction no longer applies, always replace it.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // Single-edit case: only replace when the new edit targets the
            // same range and merely extends the old text (e.g. streaming).
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
361
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    /// An edit in the buffer with this entity id.
    Buffer(EntityId),
}
367
368impl PredictionRequestedBy {
369 pub fn buffer_id(&self) -> Option<EntityId> {
370 match self {
371 PredictionRequestedBy::DiagnosticsUpdate => None,
372 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
373 }
374 }
375}
376
/// An in-flight prediction request; the task resolves with the resulting
/// prediction id, if any.
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}
381
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer; offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}
388
/// State tracked per registered buffer: the last-seen snapshot plus the edit
/// and release subscriptions that keep it up to date.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}
393
/// The still-open edit event, kept unfinalized so that subsequent nearby
/// edits can be coalesced into it.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    /// Anchor at the end of the most recent edit, used for proximity checks.
    end_edit_anchor: Option<Anchor>,
}
399
impl LastEvent {
    /// Converts the pending event into a wire-format `Event` by diffing the
    /// old and new snapshots. Returns `None` when nothing changed (same path
    /// and empty diff).
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<Arc<predict_edits_v3::Event>> {
        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);

        let file = self.new_snapshot.file();
        let old_file = self.old_snapshot.file();

        // Both the old and new file must belong to worktrees detected as open
        // source; untitled buffers (no file) are treated as not open source.
        let in_open_source_repo = [file, old_file].iter().all(|file| {
            file.is_some_and(|file| {
                license_detection_watchers
                    .get(&file.worktree_id(cx))
                    .is_some_and(|watcher| watcher.is_project_open_source())
            })
        });

        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(Arc::new(predict_edits_v3::Event::BufferChange {
                old_path,
                path,
                diff,
                in_open_source_repo,
                // TODO: Actually detect if this edit was predicted or not
                predicted: false,
            }))
        }
    }
}
436
437fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
438 if let Some(file) = snapshot.file() {
439 file.full_path(cx).into()
440 } else {
441 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
442 }
443}
444
445impl Zeta {
446 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
447 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
448 }
449
450 pub fn global(
451 client: &Arc<Client>,
452 user_store: &Entity<UserStore>,
453 cx: &mut App,
454 ) -> Entity<Self> {
455 cx.try_global::<ZetaGlobal>()
456 .map(|global| global.0.clone())
457 .unwrap_or_else(|| {
458 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
459 cx.set_global(ZetaGlobal(zeta.clone()));
460 zeta
461 })
462 }
463
    /// Creates a new `Zeta`. Also spawns a detached background loop that
    /// flushes batched prediction rejections each time the rejection channel
    /// is signalled (see `discard_prediction`).
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let (reject_tx, mut reject_rx) = mpsc::unbounded();
        cx.spawn(async move |this, cx| {
            while let Some(()) = reject_rx.next().await {
                this.update(cx, |this, cx| this.reject_edit_predictions(cx))?
                    .await
                    .log_err();
            }
            anyhow::Ok(())
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the LLM token whenever the listener reports a refresh.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            // Absence of the token is logged but not fatal; Sweep just stays
            // unavailable (see `has_sweep_api_token`).
            sweep_api_token: std::env::var("SWEEP_AI_TOKEN")
                .context("No SWEEP_AI_TOKEN environment variable set")
                .log_err(),
            data_collection_choice,
            sweep_ai_debug_info: sweep_ai::debug_info(cx),
            rejected_predictions: Vec::new(),
            reject_predictions_debounce_task: None,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
514
    /// Selects which backend serves subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }
518
    /// Whether a Sweep API token was found in the environment at startup.
    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_api_token.is_some()
    }
522
    /// Installs a cache used by the eval harness to memoize model responses.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
527
    /// Starts streaming debug events; replaces any previous subscriber.
    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }
533
    /// Current prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prediction options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
541
542 pub fn clear_history(&mut self) {
543 for zeta_project in self.projects.values_mut() {
544 zeta_project.events.clear();
545 }
546 }
547
    /// Iterates over the retrieved context ranges for `project`. Yields
    /// nothing when the project is unknown or context has not been gathered.
    pub fn context_for_project(
        &self,
        project: &Entity<Project>,
    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
        self.projects
            .get(&project.entity_id())
            .and_then(|project| {
                // `?` inside the closure handles the `context: None` case.
                Some(
                    project
                        .context
                        .as_ref()?
                        .iter()
                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
                )
            })
            .into_iter()
            .flatten()
    }
566
567 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
568 if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
569 self.user_store.read(cx).edit_prediction_usage()
570 } else {
571 None
572 }
573 }
574
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking `buffer` within `project` (snapshot + edit events).
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
588
    /// Returns the state for `project`, creating it on first use. A syntax
    /// index is only built when the current options use `ContextMode::Syntax`.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
620
    /// Reacts to project events: keeps the recently-active-paths list ordered
    /// (most recent first) and refreshes predictions on diagnostics updates.
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move-to-front with de-duplication.
                    // NOTE(review): the deque is never truncated here — confirm
                    // its growth is bounded elsewhere.
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                self.refresh_prediction_from_diagnostics(project, cx);
            }
            _ => (),
        }
    }
651
    /// Registers `buffer` in `zeta_project` (idempotent) and returns its
    /// entry. Also ensures a license-detection watcher exists for the
    /// buffer's worktree. Subscriptions keep the snapshot current and clean
    /// up state when the buffer or worktree is released.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily create a license watcher for the buffer's worktree, removing
        // it again when the worktree is released.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record every edit as a (possibly coalesced) event.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
711
    /// Records an edit to `buffer` as an event. Consecutive edits to the same
    /// buffer within `CHANGE_GROUPING_LINE_SPAN` rows are coalesced into the
    /// open `last_event`; otherwise the open event is finalized and a new one
    /// started. The finalized queue is capped at `EVENT_COUNT_MAX`.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // No-op if nothing actually changed since the recorded snapshot.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Anchor at the end of the last edit, for row-proximity comparison.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // Coalesce only when this edit directly continues the open event:
            // same buffer, same version chain, and close enough by row.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // Extend the open event in place; its old_snapshot stays.
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Make room for the event about to be finalized below.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            // `finalize` may return None (empty diff); `extend` handles that.
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
774
    /// Returns the project's current prediction from `buffer`'s perspective:
    /// `Local` when it targets this buffer, `Jump` when the user should be
    /// offered a jump to another buffer, `None` otherwise.
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            // Only the requesting buffer shows a jump; diagnostics-triggered
            // predictions may offer a jump from any buffer.
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
806
    /// Accepts the project's current prediction: cancels any pending
    /// requests and reports the acceptance to the server. No-op for Sweep.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Accepting one prediction invalidates anything still in flight.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            self.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the endpoint for testing.
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // Surfaces update-required errors and other API failures.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
859
    /// Uploads the batch of locally-recorded rejections to the server, then
    /// drains the uploaded prefix from the local list. No-op for Sweep or
    /// when there is nothing to upload.
    fn reject_edit_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        // Remember the last rejection so we only drain what was uploaded;
        // new rejections may arrive while the request is in flight.
        let last_rejection = self.rejected_predictions.last().cloned();
        let Some(last_rejection) = last_rejection else {
            return Task::ready(anyhow::Ok(()));
        };

        // NOTE(review): `.ok()` silently maps a serialization failure to an
        // empty request body — confirm that is intended.
        let body = serde_json::to_string(&RejectEditPredictionsBody {
            rejections: self.rejected_predictions.clone(),
        })
        .ok();

        cx.spawn(async move |this, cx| {
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])?;

            cx.background_spawn(Self::send_api_request::<()>(
                move |builder| {
                    let req = builder.uri(url.as_ref()).body(body.clone().into());
                    Ok(req?)
                },
                client,
                llm_token,
                app_version,
            ))
            .await
            .context("Failed to reject edit predictions")?;

            // Drain everything up to and including the snapshot's last entry.
            this.update(cx, |this, _| {
                if let Some(ix) = this
                    .rejected_predictions
                    .iter()
                    .position(|rejection| rejection.request_id == last_rejection.request_id)
                {
                    this.rejected_predictions.drain(..ix + 1);
                }
            })
        })
    }
907
908 fn discard_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
909 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
910 project_state.pending_predictions.clear();
911 if let Some(prediction) = project_state.current_prediction.take() {
912 self.discard_prediction(prediction.prediction.id, prediction.was_shown, cx);
913 }
914 };
915 }
916
917 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
918 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
919 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
920 if !current_prediction.was_shown {
921 current_prediction.was_shown = true;
922 self.shown_predictions
923 .push_front(current_prediction.prediction.clone());
924 if self.shown_predictions.len() > 50 {
925 let completion = self.shown_predictions.pop_back().unwrap();
926 self.rated_predictions.remove(&completion.id);
927 }
928 }
929 }
930 }
931 }
932
    /// Records `prediction_id` as rejected and schedules a batched upload.
    /// The flush is debounced; once the batch reaches the server's per-request
    /// limit it flushes immediately instead.
    fn discard_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        was_shown: bool,
        cx: &mut Context<Self>,
    ) {
        self.rejected_predictions.push(EditPredictionRejection {
            request_id: prediction_id.to_string(),
            was_shown,
        });

        let reached_request_limit =
            self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
        let reject_tx = self.reject_predictions_tx.clone();
        // Replacing the stored task cancels any previously scheduled flush,
        // restarting the debounce window.
        self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
            const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
            if !reached_request_limit {
                cx.background_executor()
                    .timer(DISCARD_COMPLETIONS_DEBOUNCE)
                    .await;
            }
            reject_tx.unbounded_send(()).log_err();
        }));
    }
957
    /// Awaits a cancelled pending request in the background; if it still
    /// produced a prediction, records that prediction as discarded (never
    /// shown).
    fn cancel_pending_prediction(
        &self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<Self>,
    ) {
        cx.spawn(async move |this, cx| {
            let Some(prediction_id) = pending_prediction.task.await else {
                return;
            };

            this.update(cx, |this, cx| {
                this.discard_prediction(prediction_id, false, cx);
            })
            .ok();
        })
        .detach()
    }
975
976 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
977 self.projects
978 .get(&project.entity_id())
979 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
980 }
981
982 pub fn refresh_prediction_from_buffer(
983 &mut self,
984 project: Entity<Project>,
985 buffer: Entity<Buffer>,
986 position: language::Anchor,
987 cx: &mut Context<Self>,
988 ) {
989 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
990 let Some(request_task) = this
991 .update(cx, |this, cx| {
992 this.request_prediction(&project, &buffer, position, cx)
993 })
994 .log_err()
995 else {
996 return Task::ready(anyhow::Ok(None));
997 };
998
999 let project = project.clone();
1000 cx.spawn(async move |cx| {
1001 if let Some(prediction) = request_task.await? {
1002 let id = prediction.id.clone();
1003 this.update(cx, |this, cx| {
1004 let project_state = this
1005 .projects
1006 .get_mut(&project.entity_id())
1007 .context("Project not found")?;
1008
1009 let new_prediction = CurrentEditPrediction {
1010 requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
1011 prediction: prediction,
1012 was_shown: false,
1013 };
1014
1015 if project_state
1016 .current_prediction
1017 .as_ref()
1018 .is_none_or(|old_prediction| {
1019 new_prediction.should_replace_prediction(&old_prediction, cx)
1020 })
1021 {
1022 project_state.current_prediction = Some(new_prediction);
1023 cx.notify();
1024 }
1025 anyhow::Ok(())
1026 })??;
1027 Ok(Some(id))
1028 } else {
1029 Ok(None)
1030 }
1031 })
1032 })
1033 }
1034
    /// After a diagnostics update, finds the next diagnostic location starting
    /// from the active buffer and requests a prediction there. Does nothing if
    /// a buffer-triggered prediction is already installed (those take
    /// priority), and never overwrites a prediction that appears meanwhile.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Start from the project's active entry, if it can be opened.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(&project, &jump_buffer, jump_position, cx)
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                let id = prediction.id.clone();
                this.update(cx, |this, cx| {
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        // Install only if no prediction appeared in the
                        // meantime; buffer-triggered ones win.
                        zeta_project.current_prediction.get_or_insert_with(|| {
                            cx.notify();
                            CurrentEditPrediction {
                                requested_by: PredictionRequestedBy::DiagnosticsUpdate,
                                prediction,
                                was_shown: false,
                            }
                        });
                    }
                })?;

                anyhow::Ok(Some(id))
            })
        });
    }
1107
    /// Minimum interval between prediction refreshes for the same entity;
    /// zero under `cfg(test)` so tests run without artificial delays.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1112
    /// Schedules a prediction refresh, throttling back-to-back requests for
    /// the same `throttle_entity` and keeping at most two refreshes pending.
    ///
    /// When a refresh completes, any refreshes that were enqueued before it
    /// are cancelled, since their results would be stale.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        ) -> Task<Result<Option<EditPredictionId>>>
        + 'static,
    ) {
        let zeta_project = self.get_or_init_zeta_project(&project, cx);
        let pending_prediction_id = zeta_project.next_pending_prediction_id;
        zeta_project.next_pending_prediction_id += 1;
        let last_request = zeta_project.last_prediction_refresh;

        // TODO report cancelled requests like in zeta1
        let task = cx.spawn(async move |this, cx| {
            // Throttle: if the last refresh targeted the same entity within
            // THROTTLE_TIMEOUT, sleep out the remaining time first.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let edit_prediction_id = do_refresh(this.clone(), cx).await.log_err().flatten();

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let zeta_project = this.get_or_init_zeta_project(&project, cx);
                let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        // Everything enqueued earlier is now stale.
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            this.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_zeta_project(&project, cx)
                    .pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            edit_prediction_id
        });

        // Invariant: at most two pending predictions at any time. A third
        // request replaces the second, which gets cancelled.
        if zeta_project.pending_predictions.len() <= 1 {
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if zeta_project.pending_predictions.len() == 2 {
            let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            self.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1178
1179 pub fn request_prediction(
1180 &mut self,
1181 project: &Entity<Project>,
1182 active_buffer: &Entity<Buffer>,
1183 position: language::Anchor,
1184 cx: &mut Context<Self>,
1185 ) -> Task<Result<Option<EditPrediction>>> {
1186 match self.edit_prediction_model {
1187 ZetaEditPredictionModel::Zeta1 => {
1188 request_prediction_with_zeta1(self, project, active_buffer, position, cx)
1189 }
1190 ZetaEditPredictionModel::Zeta2 => {
1191 self.request_prediction_with_zeta2(project, active_buffer, position, cx)
1192 }
1193 ZetaEditPredictionModel::Sweep => {
1194 self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
1195 }
1196 }
1197 }
1198
    /// Requests an edit prediction from the Sweep autocomplete API.
    ///
    /// Builds a brotli-compressed JSON request containing the active buffer's
    /// text, recent edit events, the heads of recently-used buffers, and
    /// diagnostics near the cursor. When the response contains no edits and
    /// `allow_jump` is set, retries once at the nearest diagnostic location
    /// outside the already-searched range (with `allow_jump = false` to avoid
    /// recursion).
    fn request_prediction_with_sweep(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = active_buffer.read(cx).snapshot();
        let debug_info = self.sweep_ai_debug_info.clone();
        // Without an API token there is nothing to do: predict nothing.
        let Some(api_token) = self.sweep_api_token.clone() else {
            return Task::ready(Ok(None));
        };
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|file| file.full_path(cx))
            .unwrap_or_else(|| "untitled".into())
            .into();

        let project_file = project::File::from_dyn(snapshot.file());
        let repo_name = project_file
            .map(|file| file.worktree.read(cx).root_name_str())
            .unwrap_or("untitled")
            .into();
        let offset = position.to_offset(&snapshot);

        let project_state = self.get_or_init_zeta_project(project, cx);
        let events = project_state.events(cx);
        let has_events = !events.is_empty();
        let recent_buffers = project_state.recent_paths.iter().cloned();
        let http_client = cx.http_client();

        // Up to three recently-used open buffers (excluding the active one)
        // are sent along as additional context.
        let recent_buffer_snapshots = recent_buffers
            .filter_map(|project_path| {
                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
                if active_buffer == &buffer {
                    None
                } else {
                    Some(buffer.read(cx).snapshot())
                }
            })
            .take(3)
            .collect::<Vec<_>>();

        // Diagnostics within this many rows of the cursor are included in the
        // request; anything further away is a candidate "jump" target later.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
        let buffer_snapshotted_at = Instant::now();

        // Request construction, network round-trip, and diff computation all
        // happen off the main thread.
        let result = cx.background_spawn({
            let snapshot = snapshot.clone();
            let diagnostic_search_range = diagnostic_search_range.clone();
            async move {
                let text = snapshot.text();

                // Serialize recent edit events into Sweep's textual format.
                let mut recent_changes = String::new();
                for event in &events {
                    sweep_ai::write_event(event.as_ref(), &mut recent_changes).unwrap();
                }

                // First ~30 lines of each recent buffer, as context chunks.
                let mut file_chunks = recent_buffer_snapshots
                    .into_iter()
                    .map(|snapshot| {
                        let end_point = Point::new(30, 0).min(snapshot.max_point());
                        sweep_ai::FileChunk {
                            content: snapshot.text_for_range(Point::zero()..end_point).collect(),
                            file_path: snapshot
                                .file()
                                .map(|f| f.path().as_unix_str())
                                .unwrap_or("untitled")
                                .to_string(),
                            start_line: 0,
                            end_line: end_point.row as usize,
                            timestamp: snapshot.file().and_then(|file| {
                                Some(
                                    file.disk_state()
                                        .mtime()?
                                        .to_seconds_and_nanos_for_persistence()?
                                        .0,
                                )
                            }),
                        }
                    })
                    .collect::<Vec<_>>();

                let diagnostic_entries =
                    snapshot.diagnostics_in_range(diagnostic_search_range, false);
                let mut diagnostic_content = String::new();
                let mut diagnostic_count = 0;

                for entry in diagnostic_entries {
                    let start_point: Point = entry.range.start;

                    // Only the four mapped severities are forwarded.
                    let severity = match entry.diagnostic.severity {
                        DiagnosticSeverity::ERROR => "error",
                        DiagnosticSeverity::WARNING => "warning",
                        DiagnosticSeverity::INFORMATION => "info",
                        DiagnosticSeverity::HINT => "hint",
                        _ => continue,
                    };

                    diagnostic_count += 1;

                    writeln!(
                        &mut diagnostic_content,
                        "{} at line {}: {}",
                        severity,
                        start_point.row + 1,
                        entry.diagnostic.message
                    )?;
                }

                // Diagnostics ride along as a pseudo file chunk.
                if !diagnostic_content.is_empty() {
                    file_chunks.push(sweep_ai::FileChunk {
                        file_path: format!("Diagnostics for {}", full_path.display()),
                        start_line: 0,
                        end_line: diagnostic_count,
                        content: diagnostic_content,
                        timestamp: None,
                    });
                }

                let request_body = sweep_ai::AutocompleteRequest {
                    debug_info,
                    repo_name,
                    file_path: full_path.clone(),
                    file_contents: text.clone(),
                    original_file_contents: text,
                    cursor_position: offset,
                    recent_changes: recent_changes.clone(),
                    changes_above_cursor: true,
                    multiple_suggestions: false,
                    branch: None,
                    file_chunks,
                    retrieval_chunks: vec![],
                    recent_user_actions: vec![],
                    // TODO
                    privacy_mode_enabled: false,
                };

                // The request body is brotli-compressed JSON (see the
                // `Content-Encoding: br` header below).
                let mut buf: Vec<u8> = Vec::new();
                let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
                serde_json::to_writer(writer, &request_body)?;
                let body: AsyncBody = buf.into();

                // Mirror of the request inputs, retained for debugging/rating.
                let inputs = EditPredictionInputs {
                    events,
                    included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
                        path: full_path.clone(),
                        max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
                        excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
                            start_line: cloud_llm_client::predict_edits_v3::Line(0),
                            text: request_body.file_contents.into(),
                        }],
                    }],
                    cursor_point: cloud_llm_client::predict_edits_v3::Point {
                        column: cursor_point.column,
                        line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
                    },
                    cursor_path: full_path.clone(),
                };

                const SWEEP_API_URL: &str =
                    "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";

                let request = http_client::Request::builder()
                    .uri(SWEEP_API_URL)
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", api_token))
                    .header("Connection", "keep-alive")
                    .header("Content-Encoding", "br")
                    .method(Method::POST)
                    .body(body)?;

                let mut response = http_client.send(request).await?;

                let mut body: Vec<u8> = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;

                let response_received_at = Instant::now();
                if !response.status().is_success() {
                    anyhow::bail!(
                        "Request failed with status: {:?}\nBody: {}",
                        response.status(),
                        String::from_utf8_lossy(&body),
                    );
                };

                let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;

                // Convert the returned completion into anchored edits by
                // diffing it against the span of text it replaces.
                let old_text = snapshot
                    .text_for_range(response.start_index..response.end_index)
                    .collect::<String>();
                let edits = language::text_diff(&old_text, &response.completion)
                    .into_iter()
                    .map(|(range, text)| {
                        (
                            snapshot.anchor_after(response.start_index + range.start)
                                ..snapshot.anchor_before(response.start_index + range.end),
                            text,
                        )
                    })
                    .collect::<Vec<_>>();

                anyhow::Ok((
                    response.autocomplete_id,
                    edits,
                    snapshot,
                    response_received_at,
                    inputs,
                ))
            }
        });

        // Two owned handles are needed: `active_buffer` is moved into
        // `next_diagnostic_location` below, `buffer` outlives it.
        let buffer = active_buffer.clone();
        let project = project.clone();
        let active_buffer = active_buffer.clone();

        cx.spawn(async move |this, cx| {
            let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;

            if edits.is_empty() {
                // No edits at the cursor: optionally jump to the nearest
                // diagnostic outside the searched range and retry once.
                if has_events
                    && allow_jump
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_with_sweep(
                                &project,
                                &jump_buffer,
                                jump_position,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            anyhow::Ok(
                EditPrediction::new(
                    EditPredictionId(id.into()),
                    &buffer,
                    &old_snapshot,
                    edits.into(),
                    buffer_snapshotted_at,
                    response_received_at,
                    inputs,
                    cx,
                )
                .await,
            )
        })
    }
1468
    /// Finds a diagnostic location for a prediction to "jump" to: the closest
    /// diagnostic in the active buffer outside the range already covered by
    /// the previous request, or failing that, the most severe diagnostic in
    /// another buffer whose path shares the most leading directories with the
    /// active buffer's path.
    ///
    /// Returns `None` when no suitable diagnostic exists anywhere.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Skip diagnostics the previous request already saw.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            // Fall back to another file with diagnostics, preferring the one
            // "closest" to the active file in the directory tree.
            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // Lower severity values are more severe (ERROR < WARNING),
                // so `min_by_key` selects the most severe diagnostic.
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1547
    /// Requests an edit prediction from the Zeta2 service.
    ///
    /// Gathers context for the cursor position (either agentic context files
    /// or syntax-index based, depending on `options.context`), collects
    /// nearby diagnostics and recent events, builds a prompt, sends it to the
    /// LLM endpoint, and parses the model's textual output (diff or XML
    /// edits, per prompt format) back into anchored buffer edits.
    fn request_prediction_with_zeta2(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        // Snapshot of the syntax index, if this project has one.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let buffer_snapshotted_at = Instant::now();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let events = project_state
            .map(|state| state.events(cx))
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        // Absolute path of the file's parent directory, when it has one.
        let file = active_buffer.read(cx).file();
        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // TODO data collection
        let can_collect_data = file
            .as_ref()
            .map_or(false, |file| self.can_collect_file(project, file, cx));

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        // Snapshot all context buffers plus their relevant ranges.
        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        // Deterministic ordering so prompts are stable across runs.
        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = Instant::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the active buffer's
                        // entry (or append a new entry), and move the active
                        // buffer to the end of the list.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Project anchor ranges to line ranges and assemble
                        // the request's excerpt payloads.
                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                // Request inputs retained for debugging and rating.
                let inputs = EditPredictionInputs {
                    included_files: cloud_request.included_files,
                    events: cloud_request.events,
                    cursor_point: cloud_request.cursor_point,
                    cursor_path: cloud_request.excerpt_path,
                };

                let retrieval_time = Instant::now() - before_retrieval;

                // Forward the request to the debug view, if one is listening.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                inputs: inputs.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Debug-build escape hatch: skip the network round-trip.
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), Duration::ZERO))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let generation_params =
                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: generation_params.stop.unwrap_or_default(),
                    temperature: generation_params.temperature.unwrap_or(0.7),
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = Instant::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let received_response_at = Instant::now();
                let request_time = received_response_at - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((None, usage));
                };

                // The model sometimes echoes the cursor marker; drop it.
                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                // Lets the diff/edit parsers resolve paths mentioned in the
                // model output to the snapshots we sent.
                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal
                    | PromptFormat::MinimalQwen
                    | PromptFormat::SeedCoder1120 => {
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                // Map the edited snapshot back to its buffer entity.
                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        inputs,
                        edited_buffer,
                        edited_buffer_snapshot.clone(),
                        edits,
                        received_response_at,
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((
                    id,
                    inputs,
                    edited_buffer,
                    edited_buffer_snapshot,
                    edits,
                    received_response_at,
                )) = Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(EditPrediction::new(
                    id,
                    &edited_buffer,
                    &edited_buffer_snapshot,
                    edits.into(),
                    buffer_snapshotted_at,
                    received_response_at,
                    inputs,
                    cx,
                )
                .await)
            }
        })
    }
1954
1955 async fn send_raw_llm_request(
1956 request: open_ai::Request,
1957 client: Arc<Client>,
1958 llm_token: LlmApiToken,
1959 app_version: Version,
1960 #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1961 #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
1962 ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1963 let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1964 http_client::Url::parse(&predict_edits_url)?
1965 } else {
1966 client
1967 .http_client()
1968 .build_zed_llm_url("/predict_edits/raw", &[])?
1969 };
1970
1971 #[cfg(feature = "eval-support")]
1972 let cache_key = if let Some(cache) = eval_cache {
1973 use collections::FxHasher;
1974 use std::hash::{Hash, Hasher};
1975
1976 let mut hasher = FxHasher::default();
1977 url.hash(&mut hasher);
1978 let request_str = serde_json::to_string_pretty(&request)?;
1979 request_str.hash(&mut hasher);
1980 let hash = hasher.finish();
1981
1982 let key = (eval_cache_kind, hash);
1983 if let Some(response_str) = cache.read(key) {
1984 return Ok((serde_json::from_str(&response_str)?, None));
1985 }
1986
1987 Some((cache, request_str, key))
1988 } else {
1989 None
1990 };
1991
1992 let (response, usage) = Self::send_api_request(
1993 |builder| {
1994 let req = builder
1995 .uri(url.as_ref())
1996 .body(serde_json::to_string(&request)?.into());
1997 Ok(req?)
1998 },
1999 client,
2000 llm_token,
2001 app_version,
2002 )
2003 .await?;
2004
2005 #[cfg(feature = "eval-support")]
2006 if let Some((cache, request, key)) = cache_key {
2007 cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
2008 }
2009
2010 Ok((response, usage))
2011 }
2012
2013 fn handle_api_response<T>(
2014 this: &WeakEntity<Self>,
2015 response: Result<(T, Option<EditPredictionUsage>)>,
2016 cx: &mut gpui::AsyncApp,
2017 ) -> Result<T> {
2018 match response {
2019 Ok((data, usage)) => {
2020 if let Some(usage) = usage {
2021 this.update(cx, |this, cx| {
2022 this.user_store.update(cx, |user_store, cx| {
2023 user_store.update_edit_prediction_usage(usage, cx);
2024 });
2025 })
2026 .ok();
2027 }
2028 Ok(data)
2029 }
2030 Err(err) => {
2031 if err.is::<ZedUpdateRequiredError>() {
2032 cx.update(|cx| {
2033 this.update(cx, |this, _cx| {
2034 this.update_required = true;
2035 })
2036 .ok();
2037
2038 let error_message: SharedString = err.to_string().into();
2039 show_app_notification(
2040 NotificationId::unique::<ZedUpdateRequiredError>(),
2041 cx,
2042 move |cx| {
2043 cx.new(|cx| {
2044 ErrorMessagePrompt::new(error_message.clone(), cx)
2045 .with_link_button("Update Zed", "https://zed.dev/releases")
2046 })
2047 },
2048 );
2049 })
2050 .ok();
2051 }
2052 Err(err)
2053 }
2054 }
2055 }
2056
    /// POSTs a request produced by `build` with auth and version headers and
    /// deserializes the JSON response.
    ///
    /// Retries exactly once with a freshly refreshed LLM token when the
    /// server marks the current one as expired. Fails with
    /// [`ZedUpdateRequiredError`] when the server's minimum required client
    /// version exceeds `app_version`.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // Loops at most twice: the initial attempt, plus one retry after a
        // token refresh.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Enforce the server-advertised minimum client version, if any.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh once and retry the request.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2120
    /// How long without a context refresh before the user counts as idle.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied before refetching context while the user keeps editing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
2123
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and again after they pause editing.
    /// Schedules a context refresh for agentic context mode: immediately when
    /// this is the first edit after an idle period, otherwise debounced until
    /// the user pauses editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context retrieval only applies to agentic context mode.
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        // Idle when no refresh has been recorded recently (or ever).
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Overwriting the previous debounce task implicitly cancels it.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Hold the task so it runs to completion even if no one
                    // awaits it.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
2171
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    //
    // Flow: select an excerpt around the cursor, ask the LLM to plan search
    // queries (via a forced tool call), run those searches over the project,
    // and store the resulting excerpts as the project's prediction context.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // No-op unless the project is registered and agentic context is enabled.
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        // If no excerpt can be selected around the cursor there is nothing to plan.
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the search-planning prompt from the excerpt plus recent edit events.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project.events(cx),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Schema + description for the search tool, computed once per process.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            // Collect search queries from the tool calls in the first choice.
            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                // Unknown tool names are logged and skipped rather than failing
                // the whole refresh.
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Store the result on the project (if still registered) and clear
            // the in-flight task handle.
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
2383
2384 pub fn set_context(
2385 &mut self,
2386 project: Entity<Project>,
2387 context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2388 ) {
2389 if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2390 zeta_project.context = Some(context);
2391 }
2392 }
2393
    /// Collects diagnostic groups from all language servers, sorts them by
    /// proximity of their primary entry to the cursor, serializes each as raw
    /// JSON, and stops once the accumulated payload would exceed
    /// `max_diagnostics_bytes`.
    ///
    /// Returns the serialized groups plus a flag indicating whether any groups
    /// were dropped due to the byte budget.
    fn gather_nearby_diagnostics(
        cursor_offset: usize,
        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
        snapshot: &BufferSnapshot,
        max_diagnostics_bytes: usize,
    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
        // TODO: Could make this more efficient
        let mut diagnostic_groups = Vec::new();
        for (language_server_id, diagnostics) in diagnostic_sets {
            let mut groups = Vec::new();
            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
            diagnostic_groups.extend(
                groups
                    .into_iter()
                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
            );
        }

        // sort by proximity to cursor
        diagnostic_groups.sort_by_key(|group| {
            let range = &group.entries[group.primary_ix].range;
            if range.start >= cursor_offset {
                // Cursor before (or at) the range: distance to its start.
                range.start - cursor_offset
            } else if cursor_offset >= range.end {
                // Cursor after (or at) the range: distance past its end.
                cursor_offset - range.end
            } else {
                // Cursor strictly inside the range: key is the distance to the
                // nearest endpoint, not 0. NOTE(review): confirm this is
                // intentional — it can rank an enclosing diagnostic below one
                // that merely touches the cursor.
                (cursor_offset - range.start).min(range.end - cursor_offset)
            }
        });

        // Greedily take the closest groups until the byte budget is exhausted.
        let mut results = Vec::new();
        let mut diagnostic_groups_truncated = false;
        let mut diagnostics_byte_count = 0;
        for group in diagnostic_groups {
            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
            diagnostics_byte_count += raw_value.get().len();
            if diagnostics_byte_count > max_diagnostics_bytes {
                diagnostic_groups_truncated = true;
                break;
            }
            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
        }

        (results, diagnostic_groups_truncated)
    }
2439
    // TODO: Dedupe with similar code in request_prediction?
    //
    /// Builds a `PredictEditsRequest` for the zeta CLI from the current buffer
    /// and cursor position, gathering syntax-based context on a background
    /// thread. Only `ContextMode::Syntax` is supported; the agentic mode panics
    /// (see TODO below).
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        // Snapshot worktrees up front so path resolution can happen off the
        // main thread.
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, if it has one.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            // Hold the syntax index lock for the duration of context gathering.
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
2517
2518 pub fn wait_for_initial_indexing(
2519 &mut self,
2520 project: &Entity<Project>,
2521 cx: &mut Context<Self>,
2522 ) -> Task<Result<()>> {
2523 let zeta_project = self.get_or_init_zeta_project(project, cx);
2524 if let Some(syntax_index) = &zeta_project.syntax_index {
2525 syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2526 } else {
2527 Task::ready(Ok(()))
2528 }
2529 }
2530
2531 fn is_file_open_source(
2532 &self,
2533 project: &Entity<Project>,
2534 file: &Arc<dyn File>,
2535 cx: &App,
2536 ) -> bool {
2537 if !file.is_local() || file.is_private() {
2538 return false;
2539 }
2540 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2541 return false;
2542 };
2543 zeta_project
2544 .license_detection_watchers
2545 .get(&file.worktree_id(cx))
2546 .as_ref()
2547 .is_some_and(|watcher| watcher.is_project_open_source())
2548 }
2549
2550 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2551 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2552 }
2553
2554 fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2555 if !self.data_collection_choice.is_enabled() {
2556 return false;
2557 }
2558 events.iter().all(|event| {
2559 matches!(
2560 event.as_ref(),
2561 Event::BufferChange {
2562 in_open_source_repo: true,
2563 ..
2564 }
2565 )
2566 })
2567 }
2568
2569 fn load_data_collection_choice() -> DataCollectionChoice {
2570 let choice = KEY_VALUE_STORE
2571 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2572 .log_err()
2573 .flatten();
2574
2575 match choice.as_deref() {
2576 Some("true") => DataCollectionChoice::Enabled,
2577 Some("false") => DataCollectionChoice::Disabled,
2578 Some(_) => {
2579 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2580 DataCollectionChoice::NotAnswered
2581 }
2582 None => DataCollectionChoice::NotAnswered,
2583 }
2584 }
2585
    /// Iterates over all predictions that have been shown to the user, in
    /// insertion order (double-ended, so callers can walk newest-first).
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2589
    /// Number of predictions shown so far.
    // NOTE(review): named "completions" while the backing field says
    // "predictions" — consider renaming for consistency (would affect callers).
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2593
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2597
    /// Records a rating (plus freeform feedback) for a shown prediction,
    /// reports it to telemetry, and flushes telemetry immediately so the
    /// rating isn't lost if the app exits soon after.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Flush right away rather than waiting for the periodic flush.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2616}
2617
2618pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2619 let choice = res.choices.pop()?;
2620 let output_text = match choice.message {
2621 open_ai::RequestMessage::Assistant {
2622 content: Some(open_ai::MessageContent::Plain(content)),
2623 ..
2624 } => content,
2625 open_ai::RequestMessage::Assistant {
2626 content: Some(open_ai::MessageContent::Multipart(mut content)),
2627 ..
2628 } => {
2629 if content.is_empty() {
2630 log::error!("No output from Baseten completion response");
2631 return None;
2632 }
2633
2634 match content.remove(0) {
2635 open_ai::MessagePart::Text { text } => text,
2636 open_ai::MessagePart::Image { .. } => {
2637 log::error!("Expected text, got an image");
2638 return None;
2639 }
2640 }
2641 }
2642 _ => {
2643 log::error!("Invalid response message: {:?}", choice.message);
2644 return None;
2645 }
2646 };
2647 Some(output_text)
2648}
2649
/// Error surfaced when the prediction service reports (via response headers)
/// that the running Zed build is older than the minimum supported version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2657
2658fn make_syntax_context_cloud_request(
2659 excerpt_path: Arc<Path>,
2660 context: EditPredictionContext,
2661 events: Vec<Arc<predict_edits_v3::Event>>,
2662 can_collect_data: bool,
2663 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2664 diagnostic_groups_truncated: bool,
2665 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2666 debug_info: bool,
2667 worktrees: &Vec<worktree::Snapshot>,
2668 index_state: Option<&SyntaxIndexState>,
2669 prompt_max_bytes: Option<usize>,
2670 prompt_format: PromptFormat,
2671) -> predict_edits_v3::PredictEditsRequest {
2672 let mut signatures = Vec::new();
2673 let mut declaration_to_signature_index = HashMap::default();
2674 let mut referenced_declarations = Vec::new();
2675
2676 for snippet in context.declarations {
2677 let project_entry_id = snippet.declaration.project_entry_id();
2678 let Some(path) = worktrees.iter().find_map(|worktree| {
2679 worktree.entry_for_id(project_entry_id).map(|entry| {
2680 let mut full_path = RelPathBuf::new();
2681 full_path.push(worktree.root_name());
2682 full_path.push(&entry.path);
2683 full_path
2684 })
2685 }) else {
2686 continue;
2687 };
2688
2689 let parent_index = index_state.and_then(|index_state| {
2690 snippet.declaration.parent().and_then(|parent| {
2691 add_signature(
2692 parent,
2693 &mut declaration_to_signature_index,
2694 &mut signatures,
2695 index_state,
2696 )
2697 })
2698 });
2699
2700 let (text, text_is_truncated) = snippet.declaration.item_text();
2701 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2702 path: path.as_std_path().into(),
2703 text: text.into(),
2704 range: snippet.declaration.item_line_range(),
2705 text_is_truncated,
2706 signature_range: snippet.declaration.signature_range_in_item_text(),
2707 parent_index,
2708 signature_score: snippet.score(DeclarationStyle::Signature),
2709 declaration_score: snippet.score(DeclarationStyle::Declaration),
2710 score_components: snippet.components,
2711 });
2712 }
2713
2714 let excerpt_parent = index_state.and_then(|index_state| {
2715 context
2716 .excerpt
2717 .parent_declarations
2718 .last()
2719 .and_then(|(parent, _)| {
2720 add_signature(
2721 *parent,
2722 &mut declaration_to_signature_index,
2723 &mut signatures,
2724 index_state,
2725 )
2726 })
2727 });
2728
2729 predict_edits_v3::PredictEditsRequest {
2730 excerpt_path,
2731 excerpt: context.excerpt_text.body,
2732 excerpt_line_range: context.excerpt.line_range,
2733 excerpt_range: context.excerpt.range,
2734 cursor_point: predict_edits_v3::Point {
2735 line: predict_edits_v3::Line(context.cursor_point.row),
2736 column: context.cursor_point.column,
2737 },
2738 referenced_declarations,
2739 included_files: vec![],
2740 signatures,
2741 excerpt_parent,
2742 events,
2743 can_collect_data,
2744 diagnostic_groups,
2745 diagnostic_groups_truncated,
2746 git_info,
2747 debug_info,
2748 prompt_max_bytes,
2749 prompt_format,
2750 }
2751}
2752
/// Ensures the signature for `declaration_id` — and, recursively, those of its
/// parent declarations — is present in `signatures`, returning its index.
/// Results are memoized in `declaration_to_signature_index` so each
/// declaration is serialized at most once. Returns `None` (and logs) when the
/// declaration is missing from the index.
fn add_signature(
    declaration_id: DeclarationId,
    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
    signatures: &mut Vec<Signature>,
    index: &SyntaxIndexState,
) -> Option<usize> {
    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
        return Some(*signature_index);
    }
    let Some(parent_declaration) = index.declaration(declaration_id) else {
        log::error!("bug: missing parent declaration");
        return None;
    };
    // Recurse first so parents are pushed before children and `parent_index`
    // always points at an earlier slot.
    let parent_index = parent_declaration.parent().and_then(|parent| {
        add_signature(parent, declaration_to_signature_index, signatures, index)
    });
    let (text, text_is_truncated) = parent_declaration.signature_text();
    let signature_index = signatures.len();
    signatures.push(Signature {
        text: text.into(),
        text_is_truncated,
        parent_index,
        range: parent_declaration.signature_line_range(),
    });
    declaration_to_signature_index.insert(declaration_id, signature_index);
    Some(signature_index)
}
2780
/// Cache key for eval runs: the kind of request plus a hash of its contents.
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2783
/// Which stage of the prediction pipeline an eval cache entry belongs to.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2791
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Lowercase stage name, as used in cache entry identifiers.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        write!(f, "{name}")
    }
}
2802
/// Pluggable request/response cache used during evals to avoid re-issuing
/// identical LLM and search requests.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached response for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` for `key`; `input` is the request that produced it.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2808
/// The user's answer to the edit-prediction data-collection prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// Collection is on only with an explicit opt-in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Whether the user has made any explicit choice.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Flips the choice; an unanswered state toggles to enabled.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}
2840
2841impl From<bool> for DataCollectionChoice {
2842 fn from(value: bool) -> Self {
2843 match value {
2844 true => DataCollectionChoice::Enabled,
2845 false => DataCollectionChoice::Disabled,
2846 }
2847 }
2848}
2849
2850struct ZedPredictUpsell;
2851
2852impl Dismissable for ZedPredictUpsell {
2853 const KEY: &'static str = "dismissed-edit-predict-upsell";
2854
2855 fn dismissed() -> bool {
2856 // To make this backwards compatible with older versions of Zed, we
2857 // check if the user has seen the previous Edit Prediction Onboarding
2858 // before, by checking the data collection choice which was written to
2859 // the database once the user clicked on "Accept and Enable"
2860 if KEY_VALUE_STORE
2861 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2862 .log_err()
2863 .is_some_and(|s| s.is_some())
2864 {
2865 return true;
2866 }
2867
2868 KEY_VALUE_STORE
2869 .read_kvp(Self::KEY)
2870 .log_err()
2871 .is_some_and(|s| s.is_some())
2872 }
2873}
2874
/// Whether the edit-prediction upsell modal should still be shown (i.e. the
/// user has neither dismissed it nor completed the legacy onboarding).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2878
/// Registers this module's workspace actions and sets up command-palette
/// feature gating for them.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Rate-completions UI is behind a feature flag even when invoked.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding clears the configured provider in settings.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2914
/// Hides this module's actions from the command palette by default, then keeps
/// their visibility in sync with the "disable AI" setting and the
/// rate-completions feature flag.
fn feature_gate_predict_edits_actions(cx: &mut App) {
    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
    let zeta_all_action_types = [
        TypeId::of::<RateCompletions>(),
        TypeId::of::<ResetOnboarding>(),
        zed_actions::OpenZedPredictOnboarding.type_id(),
        TypeId::of::<ClearHistory>(),
        TypeId::of::<ThumbsUpActivePrediction>(),
        TypeId::of::<ThumbsDownActivePrediction>(),
        TypeId::of::<NextEdit>(),
        TypeId::of::<PreviousEdit>(),
    ];

    // Everything starts hidden; the observers below selectively re-show.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_action_types(&rate_completion_action_types);
        filter.hide_action_types(&reset_onboarding_action_types);
        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
    });

    // Re-evaluate on every settings change: disabling AI hides all zeta
    // actions; otherwise rate-completions visibility follows the feature flag.
    cx.observe_global::<SettingsStore>(move |cx| {
        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            if is_ai_disabled {
                filter.hide_action_types(&zeta_all_action_types);
            } else if has_feature_flag {
                filter.show_action_types(&rate_completion_action_types);
            } else {
                filter.hide_action_types(&rate_completion_action_types);
            }
        });
    })
    .detach();

    // Also react to feature-flag changes directly (flags can arrive after
    // startup), respecting the AI-disabled setting.
    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
        if !DisableAiSettings::get_global(cx).disable_ai {
            if is_enabled {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.show_action_types(&rate_completion_action_types);
                });
            } else {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.hide_action_types(&rate_completion_action_types);
                });
            }
        }
    })
    .detach();
}
2966
2967#[cfg(test)]
2968mod tests {
2969 use std::{path::Path, sync::Arc};
2970
2971 use client::UserStore;
2972 use clock::FakeSystemClock;
2973 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2974 use futures::{
2975 AsyncReadExt, StreamExt,
2976 channel::{mpsc, oneshot},
2977 };
2978 use gpui::{
2979 Entity, TestAppContext,
2980 http_client::{FakeHttpClient, Response},
2981 prelude::*,
2982 };
2983 use indoc::indoc;
2984 use language::OffsetRangeExt as _;
2985 use open_ai::Usage;
2986 use pretty_assertions::{assert_eq, assert_matches};
2987 use project::{FakeFs, Project};
2988 use serde_json::json;
2989 use settings::SettingsStore;
2990 use util::path;
2991 use uuid::Uuid;
2992
2993 use crate::{BufferEditPrediction, Zeta};
2994
    // End-to-end check of `current_prediction_for_buffer` state transitions:
    // a local prediction in the edited buffer, an agentic context refresh, and
    // a cross-file prediction surfaced first as `Jump` then as `Local` once
    // the target buffer is opened.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // Model proposes an edit within the buffer being edited -> `Local`.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        // Respond to the search-planning request with a single tool call that
        // targets the other file, so it becomes part of the context.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.discard_current_prediction(&project, cx);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // Viewed from buffer1, an edit in 2.txt is surfaced as a `Jump`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once 2.txt itself is open, the same prediction is `Local` there.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
3141
    // Happy path for `request_prediction`: a unified-diff model response is
    // converted into a single buffer edit at the expected position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The "How" -> "How are you?" diff should surface as one insertion
        // at the cursor (line 1, column 3).
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3205
    // Verifies that edits made to a registered buffer are reported as events
    // in the prediction request prompt (as a unified diff), and that the
    // response is still applied correctly.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Register first so the subsequent edit is captured as an event.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The recorded edit must appear in the prompt as a unified diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                 Hello!
                -
                +How
                 Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3279
3280 // Skipped until we start including diagnostics in prompt
3281 // #[gpui::test]
3282 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3283 // let (zeta, mut req_rx) = init_test(cx);
3284 // let fs = FakeFs::new(cx.executor());
3285 // fs.insert_tree(
3286 // "/root",
3287 // json!({
3288 // "foo.md": "Hello!\nBye"
3289 // }),
3290 // )
3291 // .await;
3292 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3293
3294 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3295 // let diagnostic = lsp::Diagnostic {
3296 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3297 // severity: Some(lsp::DiagnosticSeverity::ERROR),
3298 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3299 // ..Default::default()
3300 // };
3301
3302 // project.update(cx, |project, cx| {
3303 // project.lsp_store().update(cx, |lsp_store, cx| {
3304 // // Create some diagnostics
3305 // lsp_store
3306 // .update_diagnostics(
3307 // LanguageServerId(0),
3308 // lsp::PublishDiagnosticsParams {
3309 // uri: path_to_buffer_uri.clone(),
3310 // diagnostics: vec![diagnostic],
3311 // version: None,
3312 // },
3313 // None,
3314 // language::DiagnosticSourceKind::Pushed,
3315 // &[],
3316 // cx,
3317 // )
3318 // .unwrap();
3319 // });
3320 // });
3321
3322 // let buffer = project
3323 // .update(cx, |project, cx| {
3324 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3325 // project.open_buffer(path, cx)
3326 // })
3327 // .await
3328 // .unwrap();
3329
3330 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3331 // let position = snapshot.anchor_before(language::Point::new(0, 0));
3332
3333 // let _prediction_task = zeta.update(cx, |zeta, cx| {
3334 // zeta.request_prediction(&project, &buffer, position, cx)
3335 // });
3336
3337 // let (request, _respond_tx) = req_rx.next().await.unwrap();
3338
3339 // assert_eq!(request.diagnostic_groups.len(), 1);
3340 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3341 // .unwrap();
3342 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3343 // assert_eq!(
3344 // value,
3345 // json!({
3346 // "entries": [{
3347 // "range": {
3348 // "start": 8,
3349 // "end": 10
3350 // },
3351 // "diagnostic": {
3352 // "source": null,
3353 // "code": null,
3354 // "code_description": null,
3355 // "severity": 1,
3356 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3357 // "markdown": null,
3358 // "group_id": 0,
3359 // "is_primary": true,
3360 // "is_disk_based": false,
3361 // "is_unnecessary": false,
3362 // "source_kind": "Pushed",
3363 // "data": null,
3364 // "underline": true
3365 // }
3366 // }],
3367 // "primary_ix": 0
3368 // })
3369 // );
3370 // }
3371
3372 fn model_response(text: &str) -> open_ai::Response {
3373 open_ai::Response {
3374 id: Uuid::new_v4().to_string(),
3375 object: "response".into(),
3376 created: 0,
3377 model: "model".into(),
3378 choices: vec![open_ai::Choice {
3379 index: 0,
3380 message: open_ai::RequestMessage::Assistant {
3381 content: Some(open_ai::MessageContent::Plain(text.to_string())),
3382 tool_calls: vec![],
3383 },
3384 finish_reason: None,
3385 }],
3386 usage: Usage {
3387 prompt_tokens: 0,
3388 completion_tokens: 0,
3389 total_tokens: 0,
3390 },
3391 }
3392 }
3393
3394 fn prompt_from_request(request: &open_ai::Request) -> &str {
3395 assert_eq!(request.messages.len(), 1);
3396 let open_ai::RequestMessage::User {
3397 content: open_ai::MessageContent::Plain(content),
3398 ..
3399 } = &request.messages[0]
3400 else {
3401 panic!(
3402 "Request does not have single user message of type Plain. {:#?}",
3403 request
3404 );
3405 };
3406 content
3407 }
3408
    /// Sets up a test `Zeta` entity backed by a fake HTTP client.
    ///
    /// Returns the `Zeta` entity together with a channel that yields each
    /// request sent to `/predict_edits/raw`, paired with a oneshot sender the
    /// test uses to supply the fake model response for that request.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            // Fake HTTP client that stands in for the cloud backend; it only
            // knows the two endpoints the tests exercise and panics on
            // anything else.
            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token endpoint: hand back a static fake token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction endpoint: deserialize the request
                            // body, forward it to the test through the
                            // channel, and wait for the test to provide the
                            // response via the oneshot sender.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
3463}