zeta2.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use arrayvec::ArrayVec;
   3use chrono::TimeDelta;
   4use client::{Client, EditPredictionUsage, UserStore};
   5use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   6use cloud_llm_client::{
   7    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   8    ZED_VERSION_HEADER_NAME,
   9};
  10use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  11use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  12use collections::HashMap;
  13use edit_prediction_context::{
  14    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  15    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  16    SyntaxIndex, SyntaxIndexState,
  17};
  18use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  19use futures::AsyncReadExt as _;
  20use futures::channel::{mpsc, oneshot};
  21use gpui::http_client::{AsyncBody, Method};
  22use gpui::{
  23    App, AsyncApp, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task,
  24    WeakEntity, http_client, prelude::*,
  25};
  26use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
  27use language::{BufferSnapshot, OffsetRangeExt};
  28use language_model::{LlmApiToken, RefreshLlmTokenListener};
  29use lsp::DiagnosticSeverity;
  30use open_ai::FunctionDefinition;
  31use project::{Project, ProjectPath};
  32use release_channel::AppVersion;
  33use serde::de::DeserializeOwned;
  34use std::collections::{VecDeque, hash_map};
  35
  36use std::fmt::Write;
  37use std::ops::Range;
  38use std::path::Path;
  39use std::str::FromStr as _;
  40use std::sync::{Arc, LazyLock};
  41use std::time::{Duration, Instant};
  42use std::{env, mem};
  43use thiserror::Error;
  44use util::rel_path::RelPathBuf;
  45use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
  46use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  47
  48pub mod assemble_excerpts;
  49mod prediction;
  50mod provider;
  51pub mod retrieval_search;
  52mod sweep_ai;
  53pub mod udiff;
  54mod xml_edits;
  55
  56use crate::assemble_excerpts::assemble_excerpts;
  57pub use crate::prediction::EditPrediction;
  58pub use crate::prediction::EditPredictionId;
  59pub use provider::ZetaEditPredictionProvider;
  60
/// Maximum number of edit events retained per project when the Sweep model
/// is active (see `report_changes_for_buffer`).
const EVENT_COUNT_MAX_SWEEP: usize = 6;
/// Maximum number of edit events retained per project when the Zed cloud
/// model is active.
const EVENT_COUNT_MAX_ZETA: usize = 16;
/// Consecutive edits within this many lines of each other are coalesced into
/// a single `Event::BufferChange` (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
  65
  66pub struct SweepFeatureFlag;
  67
  68impl FeatureFlag for SweepFeatureFlag {
  69    const NAME: &str = "sweep-ai";
  70}
/// Default sizing for the excerpt of text captured around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim for half the excerpt before the cursor, half after.
    target_before_cursor_over_total_bytes: 0.5,
};
  76
/// Context gathering defaults to the agentic mode.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Default options for [`ContextMode::Agentic`].
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};
  83
/// Default options for [`ContextMode::Syntax`] (syntax-index based context).
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };
  93
/// Default top-level options for [`Zeta`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
 102
/// When `ZED_ZETA2_OLLAMA` is set to a non-empty value, route requests to a
/// local Ollama server instead of the default hosted models.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
/// Model id for the context retrieval step; overridable via
/// `ZED_ZETA2_CONTEXT_MODEL`.
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
/// Model id for edit predictions; overridable via `ZED_ZETA2_MODEL`
/// (`"zeta2-exp"` is an alias for the fine-tuned model).
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
/// Explicit prediction endpoint from `ZED_PREDICT_EDITS_URL`; falls back to
/// the local Ollama endpoint when `USE_OLLAMA` is set, otherwise `None`.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
 130
/// Feature flag gating the zeta2 edit prediction model.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Opt-in even for staff: the flag must be enabled explicitly.
    fn enabled_for_staff() -> bool {
        false
    }
}
 140
/// App-global handle to the single [`Zeta`] entity (see [`Zeta::global`]).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 145
/// Entry point for edit prediction: tracks per-project state, requests
/// predictions from the configured model, and reports accepted predictions.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server requires a newer Zed
    // version (see MINIMUM_REQUIRED_VERSION_HEADER_NAME) — confirm where
    // this is written.
    update_required: bool,
    /// When set, debug events are streamed to this channel (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    /// Read from the `SWEEP_AI_TOKEN` environment variable at construction.
    sweep_api_token: Option<String>,
    sweep_ai_debug_info: Arc<str>,
}
 161
 162#[derive(PartialEq, Eq)]
 163pub enum ZetaEditPredictionModel {
 164    ZedCloud,
 165    Sweep,
 166}
 167
/// Tunable options for [`Zeta`].
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context for prediction requests is gathered.
    pub context: ContextMode,
    /// Upper bound on prompt size, in bytes.
    pub max_prompt_bytes: usize,
    /// Upper bound on diagnostic text included in the prompt, in bytes.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Number of files indexed concurrently by the syntax index.
    pub file_indexing_parallelism: usize,
    // NOTE(review): presumably bounds how long successive buffer edits keep
    // coalescing into one event — not referenced in this part of the file.
    pub buffer_change_grouping_interval: Duration,
}
 177
/// How context for a prediction request is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Model-driven search over the project.
    Agentic(AgenticContextOptions),
    /// Retrieval backed by a [`SyntaxIndex`] (built per project).
    Syntax(EditPredictionContextOptions),
}
 183
/// Options for [`ContextMode::Agentic`].
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
 188
 189impl ContextMode {
 190    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 191        match self {
 192            ContextMode::Agentic(options) => &options.excerpt,
 193            ContextMode::Syntax(options) => &options.excerpt,
 194        }
 195    }
 196}
 197
/// Events streamed to debug subscribers (see [`Zeta::debug_info`]).
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}
 206
/// Debug payload emitted when context retrieval begins.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// Prompt sent to the context-retrieval model.
    pub search_prompt: String,
}
 213
/// Debug payload for context-retrieval progress and completion events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}
 219
/// Debug payload emitted when an edit prediction is requested.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    // Presumably the time spent on context retrieval before the request —
    // confirm against the emitter.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// Locally-assembled prompt, or an error message if assembly failed.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error) and its elapsed time.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}
 229
/// Debug payload emitted when search queries have been generated.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

/// Request-level debug info in the wire format.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 238
/// Per-project prediction state.
struct ZetaProject {
    /// Present only when `ContextMode::Syntax` is active at registration.
    syntax_index: Option<Entity<SyntaxIndex>>,
    /// Recent edit events, oldest first; bounded and coalesced by
    /// `report_changes_for_buffer`.
    events: VecDeque<Event>,
    /// Recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered to the user, if any.
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Buffer ranges gathered as context for prediction requests.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    _subscription: gpui::Subscription,
}
 254
/// The prediction currently offered to the user, plus what triggered it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
}
 260
impl CurrentEditPrediction {
    /// Whether `self` (a freshly computed prediction) should replace
    /// `old_prediction` as the one shown to the user.
    ///
    /// Returns `false` if `self` no longer interpolates against its buffer's
    /// current contents. Otherwise replacement is preferred, except in the
    /// single-edit, same-request-buffer case where the new edit must cover
    /// the same range and extend the old text.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // A prediction that can't be re-applied to the buffer is useless.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        // Predictions for a different buffer always win.
        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // A stale old prediction is always replaced.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // Replace only when the new edit is a same-range extension of
            // the old one (e.g. the model typed more of the same text).
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
 299
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// A diagnostics update on the project.
    DiagnosticsUpdate,
    /// An edit or refresh in the identified buffer.
    Buffer(EntityId),
}
 305
 306impl PredictionRequestedBy {
 307    pub fn buffer_id(&self) -> Option<EntityId> {
 308        match self {
 309            PredictionRequestedBy::DiagnosticsUpdate => None,
 310            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 311        }
 312    }
 313}
 314
/// An in-flight prediction request and the task driving it.
struct PendingPrediction {
    id: usize,
    _task: Task<()>,
}
 319
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer; offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}
 326
/// State tracked for each buffer registered with [`Zeta`].
struct RegisteredBuffer {
    /// Snapshot as of the last change processed by
    /// `report_changes_for_buffer`.
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}
 331
/// An entry in a project's edit history, sent to the model as context.
#[derive(Clone)]
pub enum Event {
    /// A (possibly coalesced) buffer edit.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// End of the last edit; used for proximity-based coalescing.
        end_edit_anchor: Option<Anchor>,
        timestamp: Instant,
    },
}
 341
impl Event {
    /// Converts this event into the wire format for a prediction request.
    ///
    /// Returns `None` only when both paths are `None` and the diff is empty.
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                // Only report the old path when it differs (i.e. a rename).
                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2] move to bg?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                // NOTE(review): `path == old_path` only holds when both are
                // `None`, since an unchanged path leaves `old_path` as
                // `None`; so a file-backed buffer with an empty diff still
                // produces an event — confirm that's intended.
                if path == old_path && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        //todo: Actually detect if this edit was predicted or not
                        predicted: false,
                    })
                }
            }
        }
    }

    /// The project path of the buffer this event applies to, if file-backed.
    pub fn project_path(&self, cx: &App) -> Option<project::ProjectPath> {
        match self {
            Event::BufferChange { new_snapshot, .. } => new_snapshot
                .file()
                .map(|f| project::ProjectPath::from_file(f.as_ref(), cx)),
        }
    }
}
 387
 388impl Zeta {
 389    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 390        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 391    }
 392
 393    pub fn global(
 394        client: &Arc<Client>,
 395        user_store: &Entity<UserStore>,
 396        cx: &mut App,
 397    ) -> Entity<Self> {
 398        cx.try_global::<ZetaGlobal>()
 399            .map(|global| global.0.clone())
 400            .unwrap_or_else(|| {
 401                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 402                cx.set_global(ZetaGlobal(zeta.clone()));
 403                zeta
 404            })
 405    }
 406
 407    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 408        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 409
 410        Self {
 411            projects: HashMap::default(),
 412            client,
 413            user_store,
 414            options: DEFAULT_OPTIONS,
 415            llm_token: LlmApiToken::default(),
 416            _llm_token_subscription: cx.subscribe(
 417                &refresh_llm_token_listener,
 418                |this, _listener, _event, cx| {
 419                    let client = this.client.clone();
 420                    let llm_token = this.llm_token.clone();
 421                    cx.spawn(async move |_this, _cx| {
 422                        llm_token.refresh(&client).await?;
 423                        anyhow::Ok(())
 424                    })
 425                    .detach_and_log_err(cx);
 426                },
 427            ),
 428            update_required: false,
 429            debug_tx: None,
 430            #[cfg(feature = "eval-support")]
 431            eval_cache: None,
 432            edit_prediction_model: ZetaEditPredictionModel::ZedCloud,
 433            sweep_api_token: std::env::var("SWEEP_AI_TOKEN")
 434                .context("No SWEEP_AI_TOKEN environment variable set")
 435                .log_err(),
 436            sweep_ai_debug_info: sweep_ai::debug_info(cx),
 437        }
 438    }
 439
    /// Selects which backend serves edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }
 443
    /// Whether a `SWEEP_AI_TOKEN` was found in the environment at startup.
    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_api_token.is_some()
    }
 447
    /// Installs the eval cache (eval-support builds only).
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 452
 453    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 454        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 455        self.debug_tx = Some(debug_watch_tx);
 456        debug_watch_rx
 457    }
 458
    /// The currently configured options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 462
    /// Replaces the configured options.
    ///
    /// NOTE(review): already-registered projects keep the syntax-index
    /// decision made at registration time (see `get_or_init_zeta_project`).
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 466
 467    pub fn clear_history(&mut self) {
 468        for zeta_project in self.projects.values_mut() {
 469            zeta_project.events.clear();
 470        }
 471    }
 472
 473    pub fn history_for_project(
 474        &self,
 475        project: &Entity<Project>,
 476    ) -> impl DoubleEndedIterator<Item = &Event> {
 477        self.projects
 478            .get(&project.entity_id())
 479            .map(|project| project.events.iter())
 480            .into_iter()
 481            .flatten()
 482    }
 483
 484    pub fn context_for_project(
 485        &self,
 486        project: &Entity<Project>,
 487    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 488        self.projects
 489            .get(&project.entity_id())
 490            .and_then(|project| {
 491                Some(
 492                    project
 493                        .context
 494                        .as_ref()?
 495                        .iter()
 496                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 497                )
 498            })
 499            .into_iter()
 500            .flatten()
 501    }
 502
 503    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 504        if self.edit_prediction_model == ZetaEditPredictionModel::ZedCloud {
 505            self.user_store.read(cx).edit_prediction_usage()
 506        } else {
 507            None
 508        }
 509    }
 510
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }
 514
    /// Starts tracking `buffer` within `project`, subscribing to its edit
    /// and release events.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 524
    /// Returns the state for `project`, creating it on first use.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                // A syntax index is only needed for syntax-based context.
                // NOTE(review): this decision is made once at registration
                // and not revisited if `set_options` later changes the mode.
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
 554
 555    fn handle_project_event(
 556        &mut self,
 557        project: Entity<Project>,
 558        event: &project::Event,
 559        cx: &mut Context<Self>,
 560    ) {
 561        // TODO [zeta2] init with recent paths
 562        match event {
 563            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 564                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
 565                    return;
 566                };
 567                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 568                if let Some(path) = path {
 569                    if let Some(ix) = zeta_project
 570                        .recent_paths
 571                        .iter()
 572                        .position(|probe| probe == &path)
 573                    {
 574                        zeta_project.recent_paths.remove(ix);
 575                    }
 576                    zeta_project.recent_paths.push_front(path);
 577                }
 578            }
 579            project::Event::DiagnosticsUpdated { .. } => {
 580                self.refresh_prediction_from_diagnostics(project, cx);
 581            }
 582            _ => (),
 583        }
 584    }
 585
    /// Returns the [`RegisteredBuffer`] for `buffer`, registering it (and
    /// wiring up its edit/release subscriptions) on first use.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record each edit into the project's event history.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop per-buffer state when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 623
 624    fn report_changes_for_buffer(
 625        &mut self,
 626        buffer: &Entity<Buffer>,
 627        project: &Entity<Project>,
 628        cx: &mut Context<Self>,
 629    ) {
 630        let event_count_max = match self.edit_prediction_model {
 631            ZetaEditPredictionModel::ZedCloud => EVENT_COUNT_MAX_ZETA,
 632            ZetaEditPredictionModel::Sweep => EVENT_COUNT_MAX_SWEEP,
 633        };
 634
 635        let sweep_ai_project = self.get_or_init_zeta_project(project, cx);
 636        let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
 637
 638        let new_snapshot = buffer.read(cx).snapshot();
 639        if new_snapshot.version == registered_buffer.snapshot.version {
 640            return;
 641        }
 642
 643        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 644        let end_edit_anchor = new_snapshot
 645            .anchored_edits_since::<Point>(&old_snapshot.version)
 646            .last()
 647            .map(|(_, range)| range.end);
 648        let events = &mut sweep_ai_project.events;
 649
 650        if let Some(Event::BufferChange {
 651            new_snapshot: last_new_snapshot,
 652            end_edit_anchor: last_end_edit_anchor,
 653            ..
 654        }) = events.back_mut()
 655        {
 656            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
 657                == last_new_snapshot.remote_id()
 658                && old_snapshot.version == last_new_snapshot.version;
 659
 660            let should_coalesce = is_next_snapshot_of_same_buffer
 661                && end_edit_anchor
 662                    .as_ref()
 663                    .zip(last_end_edit_anchor.as_ref())
 664                    .is_some_and(|(a, b)| {
 665                        let a = a.to_point(&new_snapshot);
 666                        let b = b.to_point(&new_snapshot);
 667                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
 668                    });
 669
 670            if should_coalesce {
 671                *last_end_edit_anchor = end_edit_anchor;
 672                *last_new_snapshot = new_snapshot;
 673                return;
 674            }
 675        }
 676
 677        if events.len() >= event_count_max {
 678            events.pop_front();
 679        }
 680
 681        events.push_back(Event::BufferChange {
 682            old_snapshot,
 683            new_snapshot,
 684            end_edit_anchor,
 685            timestamp: Instant::now(),
 686        });
 687    }
 688
 689    fn current_prediction_for_buffer(
 690        &self,
 691        buffer: &Entity<Buffer>,
 692        project: &Entity<Project>,
 693        cx: &App,
 694    ) -> Option<BufferEditPrediction<'_>> {
 695        let project_state = self.projects.get(&project.entity_id())?;
 696
 697        let CurrentEditPrediction {
 698            requested_by,
 699            prediction,
 700        } = project_state.current_prediction.as_ref()?;
 701
 702        if prediction.targets_buffer(buffer.read(cx)) {
 703            Some(BufferEditPrediction::Local { prediction })
 704        } else {
 705            let show_jump = match requested_by {
 706                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
 707                    requested_by_buffer_id == &buffer.entity_id()
 708                }
 709                PredictionRequestedBy::DiagnosticsUpdate => true,
 710            };
 711
 712            if show_jump {
 713                Some(BufferEditPrediction::Jump { prediction })
 714            } else {
 715                None
 716            }
 717        }
 718    }
 719
    /// Reports acceptance of the current prediction to the server and clears
    /// any pending requests. No-op for non-cloud models.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        if self.edit_prediction_model != ZetaEditPredictionModel::ZedCloud {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Accepting consumes the current prediction.
        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        project_state.pending_predictions.clear();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the default endpoint.
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // NOTE(review): response/error handling is delegated to
            // `handle_api_response`, which is defined outside this chunk.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 769
 770    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 771        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 772            project_state.current_prediction.take();
 773            project_state.pending_predictions.clear();
 774        };
 775    }
 776
 777    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
 778        self.projects
 779            .get(&project.entity_id())
 780            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
 781    }
 782
 783    pub fn refresh_prediction_from_buffer(
 784        &mut self,
 785        project: Entity<Project>,
 786        buffer: Entity<Buffer>,
 787        position: language::Anchor,
 788        cx: &mut Context<Self>,
 789    ) {
 790        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
 791            let Some(request_task) = this
 792                .update(cx, |this, cx| {
 793                    this.request_prediction(&project, &buffer, position, cx)
 794                })
 795                .log_err()
 796            else {
 797                return Task::ready(anyhow::Ok(()));
 798            };
 799
 800            let project = project.clone();
 801            cx.spawn(async move |cx| {
 802                if let Some(prediction) = request_task.await? {
 803                    this.update(cx, |this, cx| {
 804                        let project_state = this
 805                            .projects
 806                            .get_mut(&project.entity_id())
 807                            .context("Project not found")?;
 808
 809                        let new_prediction = CurrentEditPrediction {
 810                            requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
 811                            prediction: prediction,
 812                        };
 813
 814                        if project_state
 815                            .current_prediction
 816                            .as_ref()
 817                            .is_none_or(|old_prediction| {
 818                                new_prediction.should_replace_prediction(&old_prediction, cx)
 819                            })
 820                        {
 821                            project_state.current_prediction = Some(new_prediction);
 822                            cx.notify();
 823                        }
 824                        anyhow::Ok(())
 825                    })??;
 826                }
 827                Ok(())
 828            })
 829        })
 830    }
 831
    /// Queues a prediction refresh driven by project diagnostics rather than
    /// by typing in a buffer: opens the project's active entry, finds the next
    /// diagnostic location from it, and requests a prediction there.
    ///
    /// Does nothing when the project is untracked or a prediction already
    /// exists (buffer-driven predictions take priority over diagnostics-driven
    /// ones).
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer edits: never displace an existing one.
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Resolve the project's active entry to an open buffer; bail
            // quietly (logging any error) if there is no active entry.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(()));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Default (empty) search range / cursor point: any diagnostic
                // in the active buffer is eligible as a jump target.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(());
                };

                let Some(prediction) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(&project, &jump_buffer, jump_position, cx)
                    })?
                    .await?
                else {
                    return anyhow::Ok(());
                };

                this.update(cx, |this, cx| {
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        // Only install the prediction if none arrived in the
                        // meantime; `cx.notify()` fires only on actual insert.
                        zeta_project.current_prediction.get_or_insert_with(|| {
                            cx.notify();
                            CurrentEditPrediction {
                                requested_by: PredictionRequestedBy::DiagnosticsUpdate,
                                prediction,
                            }
                        });
                    }
                })?;

                anyhow::Ok(())
            })
        });
    }
 902
    // Minimum spacing between consecutive prediction refreshes for the same
    // throttle entity (see `queue_prediction_refresh`). Zero under `cfg(test)`
    // so tests don't incur artificial waits.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
 907
 908    fn queue_prediction_refresh(
 909        &mut self,
 910        project: Entity<Project>,
 911        throttle_entity: EntityId,
 912        cx: &mut Context<Self>,
 913        do_refresh: impl FnOnce(WeakEntity<Self>, &mut AsyncApp) -> Task<Result<()>> + 'static,
 914    ) {
 915        let zeta_project = self.get_or_init_zeta_project(&project, cx);
 916        let pending_prediction_id = zeta_project.next_pending_prediction_id;
 917        zeta_project.next_pending_prediction_id += 1;
 918        let last_request = zeta_project.last_prediction_refresh;
 919
 920        // TODO report cancelled requests like in zeta1
 921        let task = cx.spawn(async move |this, cx| {
 922            if let Some((last_entity, last_timestamp)) = last_request
 923                && throttle_entity == last_entity
 924                && let Some(timeout) =
 925                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
 926            {
 927                cx.background_executor().timer(timeout).await;
 928            }
 929
 930            do_refresh(this.clone(), cx).await.log_err();
 931
 932            this.update(cx, |this, cx| {
 933                let zeta_project = this.get_or_init_zeta_project(&project, cx);
 934
 935                if zeta_project.pending_predictions[0].id == pending_prediction_id {
 936                    zeta_project.pending_predictions.remove(0);
 937                } else {
 938                    zeta_project.pending_predictions.clear();
 939                }
 940
 941                cx.notify();
 942            })
 943            .ok();
 944        });
 945
 946        if zeta_project.pending_predictions.len() <= 1 {
 947            zeta_project.pending_predictions.push(PendingPrediction {
 948                id: pending_prediction_id,
 949                _task: task,
 950            });
 951        } else if zeta_project.pending_predictions.len() == 2 {
 952            zeta_project.pending_predictions.pop();
 953            zeta_project.pending_predictions.push(PendingPrediction {
 954                id: pending_prediction_id,
 955                _task: task,
 956            });
 957        }
 958    }
 959
 960    pub fn request_prediction(
 961        &mut self,
 962        project: &Entity<Project>,
 963        active_buffer: &Entity<Buffer>,
 964        position: language::Anchor,
 965        cx: &mut Context<Self>,
 966    ) -> Task<Result<Option<EditPrediction>>> {
 967        match self.edit_prediction_model {
 968            ZetaEditPredictionModel::ZedCloud => {
 969                self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
 970            }
 971            ZetaEditPredictionModel::Sweep => {
 972                self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
 973            }
 974        }
 975    }
 976
    /// Requests an edit prediction from the Sweep autocomplete backend.
    ///
    /// Builds the request from the active buffer's full text, recorded edit
    /// events, prefixes of up to three other recently used buffers, and
    /// diagnostics near the cursor; sends it Brotli-compressed over HTTP, then
    /// diffs the returned completion against the buffer to produce anchored
    /// edits. If the response yields no edits and `allow_jump` is true, the
    /// request is retried once at the nearest diagnostic outside the original
    /// search window (`allow_jump: false` on the retry prevents recursion).
    ///
    /// Returns `Ok(None)` when no API token is configured, the response
    /// produced no edits, or the edits no longer apply to the buffer.
    fn request_prediction_with_sweep(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = active_buffer.read(cx).snapshot();
        let debug_info = self.sweep_ai_debug_info.clone();
        // No token configured: silently produce no prediction.
        let Some(api_token) = self.sweep_api_token.clone() else {
            return Task::ready(Ok(None));
        };
        // Path reported to the backend; unsaved buffers are labeled "untitled".
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|file| file.full_path(cx))
            .unwrap_or_else(|| "untitled".into())
            .into();

        let project_file = project::File::from_dyn(snapshot.file());
        let repo_name = project_file
            .map(|file| file.worktree.read(cx).root_name_str())
            .unwrap_or("untitled")
            .into();
        let offset = position.to_offset(&snapshot);

        let project_state = self.get_or_init_zeta_project(project, cx);
        let events = project_state.events.clone();
        let has_events = !events.is_empty();
        let recent_buffers = project_state.recent_paths.iter().cloned();
        let http_client = cx.http_client();

        // Snapshots of up to three other recently used open buffers,
        // sent as extra context chunks (the active buffer is excluded).
        let recent_buffer_snapshots = recent_buffers
            .filter_map(|project_path| {
                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
                if active_buffer == &buffer {
                    None
                } else {
                    Some(buffer.read(cx).snapshot())
                }
            })
            .take(3)
            .collect::<Vec<_>>();

        // Diagnostics within this many rows of the cursor are included in the
        // request; this same window later defines what counts as "nearby" for
        // the jump fallback.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        // Request assembly, network round-trip, and diffing all happen off the
        // main thread.
        let result = cx.background_spawn({
            let snapshot = snapshot.clone();
            let diagnostic_search_range = diagnostic_search_range.clone();
            async move {
                let text = snapshot.text();

                // Serialize recorded edit events into the textual
                // "recent changes" form Sweep expects.
                let mut recent_changes = String::new();
                for event in events {
                    sweep_ai::write_event(event, &mut recent_changes).unwrap();
                }

                // Only the first 30 lines of each recent buffer are included.
                let mut file_chunks = recent_buffer_snapshots
                    .into_iter()
                    .map(|snapshot| {
                        let end_point = Point::new(30, 0).min(snapshot.max_point());
                        sweep_ai::FileChunk {
                            content: snapshot.text_for_range(Point::zero()..end_point).collect(),
                            file_path: snapshot
                                .file()
                                .map(|f| f.path().as_unix_str())
                                .unwrap_or("untitled")
                                .to_string(),
                            start_line: 0,
                            end_line: end_point.row as usize,
                            // File mtime in whole seconds, if known.
                            timestamp: snapshot.file().and_then(|file| {
                                Some(
                                    file.disk_state()
                                        .mtime()?
                                        .to_seconds_and_nanos_for_persistence()?
                                        .0,
                                )
                            }),
                        }
                    })
                    .collect::<Vec<_>>();

                // Render nearby diagnostics to plain text, one line each.
                let diagnostic_entries =
                    snapshot.diagnostics_in_range(diagnostic_search_range, false);
                let mut diagnostic_content = String::new();
                let mut diagnostic_count = 0;

                for entry in diagnostic_entries {
                    let start_point: Point = entry.range.start;

                    // Unknown severities are skipped entirely.
                    let severity = match entry.diagnostic.severity {
                        DiagnosticSeverity::ERROR => "error",
                        DiagnosticSeverity::WARNING => "warning",
                        DiagnosticSeverity::INFORMATION => "info",
                        DiagnosticSeverity::HINT => "hint",
                        _ => continue,
                    };

                    diagnostic_count += 1;

                    // 1-based line numbers for human/model readability.
                    writeln!(
                        &mut diagnostic_content,
                        "{} at line {}: {}",
                        severity,
                        start_point.row + 1,
                        entry.diagnostic.message
                    )?;
                }

                // Diagnostics ride along as a pseudo file chunk.
                if !diagnostic_content.is_empty() {
                    file_chunks.push(sweep_ai::FileChunk {
                        file_path: format!("Diagnostics for {}", full_path.display()),
                        start_line: 0,
                        end_line: diagnostic_count,
                        content: diagnostic_content,
                        timestamp: None,
                    });
                }

                let request_body = sweep_ai::AutocompleteRequest {
                    debug_info,
                    repo_name,
                    file_path: full_path.clone(),
                    file_contents: text.clone(),
                    // NOTE(review): current text is sent as both the current
                    // and "original" contents — presumably the pre-edit text
                    // is not tracked here; confirm against the Sweep API.
                    original_file_contents: text,
                    cursor_position: offset,
                    recent_changes: recent_changes.clone(),
                    changes_above_cursor: true,
                    multiple_suggestions: false,
                    branch: None,
                    file_chunks,
                    retrieval_chunks: vec![],
                    recent_user_actions: vec![],
                    // TODO
                    privacy_mode_enabled: false,
                };

                // Serialize JSON straight through a Brotli compressor
                // (4096-byte buffer, quality 11, 2^22 window).
                let mut buf: Vec<u8> = Vec::new();
                let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
                serde_json::to_writer(writer, &request_body)?;
                let body: AsyncBody = buf.into();

                const SWEEP_API_URL: &str =
                    "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";

                let request = http_client::Request::builder()
                    .uri(SWEEP_API_URL)
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", api_token))
                    .header("Connection", "keep-alive")
                    .header("Content-Encoding", "br")
                    .method(Method::POST)
                    .body(body)?;

                let mut response = http_client.send(request).await?;

                let mut body: Vec<u8> = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;

                if !response.status().is_success() {
                    anyhow::bail!(
                        "Request failed with status: {:?}\nBody: {}",
                        response.status(),
                        String::from_utf8_lossy(&body),
                    );
                };

                let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;

                // Diff the completion against the text of the range it
                // replaces, anchoring each hunk so the edits can later be
                // re-validated against a newer buffer snapshot.
                let old_text = snapshot
                    .text_for_range(response.start_index..response.end_index)
                    .collect::<String>();
                let edits = language::text_diff(&old_text, &response.completion)
                    .into_iter()
                    .map(|(range, text)| {
                        (
                            snapshot.anchor_after(response.start_index + range.start)
                                ..snapshot.anchor_before(response.start_index + range.end),
                            text,
                        )
                    })
                    .collect::<Vec<_>>();

                anyhow::Ok((response.autocomplete_id, edits, snapshot))
            }
        });

        let buffer = active_buffer.clone();
        let project = project.clone();
        let active_buffer = active_buffer.clone();

        cx.spawn(async move |this, cx| {
            let (id, edits, old_snapshot) = result.await?;

            if edits.is_empty() {
                // No edits at the cursor: optionally retry once at the
                // closest diagnostic outside the original search window.
                // `false` below disables further jumping, so this recurses at
                // most one level deep.
                if has_events
                    && allow_jump
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_with_sweep(
                                &project,
                                &jump_buffer,
                                jump_position,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            // Re-validate the edits against the buffer's current state; the
            // buffer may have changed while the request was in flight.
            // Interpolation returning `None` means the edits no longer apply.
            let Some((edits, new_snapshot, preview_task)) =
                buffer.read_with(cx, |buffer, cx| {
                    let new_snapshot = buffer.snapshot();

                    let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
                        edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
                            .into();
                    let preview_task = buffer.preview_edits(edits.clone(), cx);

                    Some((edits, new_snapshot, preview_task))
                })?
            else {
                return anyhow::Ok(None);
            };

            let prediction = EditPrediction {
                id: EditPredictionId(id.into()),
                edits,
                snapshot: new_snapshot,
                edit_preview: preview_task.await,
                buffer,
            };

            anyhow::Ok(Some(prediction))
        })
    }
1232
    /// Finds the next diagnostic to jump to for a follow-up prediction.
    ///
    /// First looks in the active buffer for the diagnostic group (by its
    /// primary entry) closest to the cursor by row distance, skipping any that
    /// overlap `active_buffer_diagnostic_search_range` (those were already
    /// covered by the previous request). If none qualifies, falls back to
    /// another buffer in the project that has diagnostics, preferring the one
    /// whose path shares the most leading components with the active buffer's
    /// path, and returns the position of its most severe diagnostic.
    ///
    /// Returns `Ok(None)` when no suitable diagnostic exists anywhere.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Find the closest diagnostic to the cursor that wasn't close enough
        // to have been included in the last request's search window.
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            // Closest by absolute row distance from the cursor.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // Nothing in the active buffer — look for the most closely
            // related other buffer that has diagnostics.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // Lower `DiagnosticSeverity` values are more severe
                // (LSP: Error=1 … Hint=4), so `min_by_key` selects the
                // most severe diagnostic in the buffer.
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1311
1312    fn request_prediction_with_zed_cloud(
1313        &mut self,
1314        project: &Entity<Project>,
1315        active_buffer: &Entity<Buffer>,
1316        position: language::Anchor,
1317        cx: &mut Context<Self>,
1318    ) -> Task<Result<Option<EditPrediction>>> {
1319        let project_state = self.projects.get(&project.entity_id());
1320
1321        let index_state = project_state.and_then(|state| {
1322            state
1323                .syntax_index
1324                .as_ref()
1325                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
1326        });
1327        let options = self.options.clone();
1328        let active_snapshot = active_buffer.read(cx).snapshot();
1329        let Some(excerpt_path) = active_snapshot
1330            .file()
1331            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
1332        else {
1333            return Task::ready(Err(anyhow!("No file path for excerpt")));
1334        };
1335        let client = self.client.clone();
1336        let llm_token = self.llm_token.clone();
1337        let app_version = AppVersion::global(cx);
1338        let worktree_snapshots = project
1339            .read(cx)
1340            .worktrees(cx)
1341            .map(|worktree| worktree.read(cx).snapshot())
1342            .collect::<Vec<_>>();
1343        let debug_tx = self.debug_tx.clone();
1344
1345        let events = project_state
1346            .map(|state| {
1347                state
1348                    .events
1349                    .iter()
1350                    .filter_map(|event| event.to_request_event(cx))
1351                    .collect::<Vec<_>>()
1352            })
1353            .unwrap_or_default();
1354
1355        let diagnostics = active_snapshot.diagnostic_sets().clone();
1356
1357        let parent_abs_path =
1358            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
1359                let mut path = f.worktree.read(cx).absolutize(&f.path);
1360                if path.pop() { Some(path) } else { None }
1361            });
1362
1363        // TODO data collection
1364        let can_collect_data = cx.is_staff();
1365
1366        let empty_context_files = HashMap::default();
1367        let context_files = project_state
1368            .and_then(|project_state| project_state.context.as_ref())
1369            .unwrap_or(&empty_context_files);
1370
1371        #[cfg(feature = "eval-support")]
1372        let parsed_fut = futures::future::join_all(
1373            context_files
1374                .keys()
1375                .map(|buffer| buffer.read(cx).parsing_idle()),
1376        );
1377
1378        let mut included_files = context_files
1379            .iter()
1380            .filter_map(|(buffer_entity, ranges)| {
1381                let buffer = buffer_entity.read(cx);
1382                Some((
1383                    buffer_entity.clone(),
1384                    buffer.snapshot(),
1385                    buffer.file()?.full_path(cx).into(),
1386                    ranges.clone(),
1387                ))
1388            })
1389            .collect::<Vec<_>>();
1390
1391        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
1392            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
1393        });
1394
1395        #[cfg(feature = "eval-support")]
1396        let eval_cache = self.eval_cache.clone();
1397
1398        let request_task = cx.background_spawn({
1399            let active_buffer = active_buffer.clone();
1400            async move {
1401                #[cfg(feature = "eval-support")]
1402                parsed_fut.await;
1403
1404                let index_state = if let Some(index_state) = index_state {
1405                    Some(index_state.lock_owned().await)
1406                } else {
1407                    None
1408                };
1409
1410                let cursor_offset = position.to_offset(&active_snapshot);
1411                let cursor_point = cursor_offset.to_point(&active_snapshot);
1412
1413                let before_retrieval = chrono::Utc::now();
1414
1415                let (diagnostic_groups, diagnostic_groups_truncated) =
1416                    Self::gather_nearby_diagnostics(
1417                        cursor_offset,
1418                        &diagnostics,
1419                        &active_snapshot,
1420                        options.max_diagnostic_bytes,
1421                    );
1422
1423                let cloud_request = match options.context {
1424                    ContextMode::Agentic(context_options) => {
1425                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
1426                            cursor_point,
1427                            &active_snapshot,
1428                            &context_options.excerpt,
1429                            index_state.as_deref(),
1430                        ) else {
1431                            return Ok((None, None));
1432                        };
1433
1434                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
1435                            ..active_snapshot.anchor_before(excerpt.range.end);
1436
1437                        if let Some(buffer_ix) =
1438                            included_files.iter().position(|(_, snapshot, _, _)| {
1439                                snapshot.remote_id() == active_snapshot.remote_id()
1440                            })
1441                        {
1442                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
1443                            ranges.push(excerpt_anchor_range);
1444                            retrieval_search::merge_anchor_ranges(ranges, buffer);
1445                            let last_ix = included_files.len() - 1;
1446                            included_files.swap(buffer_ix, last_ix);
1447                        } else {
1448                            included_files.push((
1449                                active_buffer.clone(),
1450                                active_snapshot.clone(),
1451                                excerpt_path.clone(),
1452                                vec![excerpt_anchor_range],
1453                            ));
1454                        }
1455
1456                        let included_files = included_files
1457                            .iter()
1458                            .map(|(_, snapshot, path, ranges)| {
1459                                let ranges = ranges
1460                                    .iter()
1461                                    .map(|range| {
1462                                        let point_range = range.to_point(&snapshot);
1463                                        Line(point_range.start.row)..Line(point_range.end.row)
1464                                    })
1465                                    .collect::<Vec<_>>();
1466                                let excerpts = assemble_excerpts(&snapshot, ranges);
1467                                predict_edits_v3::IncludedFile {
1468                                    path: path.clone(),
1469                                    max_row: Line(snapshot.max_point().row),
1470                                    excerpts,
1471                                }
1472                            })
1473                            .collect::<Vec<_>>();
1474
1475                        predict_edits_v3::PredictEditsRequest {
1476                            excerpt_path,
1477                            excerpt: String::new(),
1478                            excerpt_line_range: Line(0)..Line(0),
1479                            excerpt_range: 0..0,
1480                            cursor_point: predict_edits_v3::Point {
1481                                line: predict_edits_v3::Line(cursor_point.row),
1482                                column: cursor_point.column,
1483                            },
1484                            included_files,
1485                            referenced_declarations: vec![],
1486                            events,
1487                            can_collect_data,
1488                            diagnostic_groups,
1489                            diagnostic_groups_truncated,
1490                            debug_info: debug_tx.is_some(),
1491                            prompt_max_bytes: Some(options.max_prompt_bytes),
1492                            prompt_format: options.prompt_format,
1493                            // TODO [zeta2]
1494                            signatures: vec![],
1495                            excerpt_parent: None,
1496                            git_info: None,
1497                        }
1498                    }
1499                    ContextMode::Syntax(context_options) => {
1500                        let Some(context) = EditPredictionContext::gather_context(
1501                            cursor_point,
1502                            &active_snapshot,
1503                            parent_abs_path.as_deref(),
1504                            &context_options,
1505                            index_state.as_deref(),
1506                        ) else {
1507                            return Ok((None, None));
1508                        };
1509
1510                        make_syntax_context_cloud_request(
1511                            excerpt_path,
1512                            context,
1513                            events,
1514                            can_collect_data,
1515                            diagnostic_groups,
1516                            diagnostic_groups_truncated,
1517                            None,
1518                            debug_tx.is_some(),
1519                            &worktree_snapshots,
1520                            index_state.as_deref(),
1521                            Some(options.max_prompt_bytes),
1522                            options.prompt_format,
1523                        )
1524                    }
1525                };
1526
1527                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
1528
1529                let retrieval_time = chrono::Utc::now() - before_retrieval;
1530
1531                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
1532                    let (response_tx, response_rx) = oneshot::channel();
1533
1534                    debug_tx
1535                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
1536                            ZetaEditPredictionDebugInfo {
1537                                request: cloud_request.clone(),
1538                                retrieval_time,
1539                                buffer: active_buffer.downgrade(),
1540                                local_prompt: match prompt_result.as_ref() {
1541                                    Ok((prompt, _)) => Ok(prompt.clone()),
1542                                    Err(err) => Err(err.to_string()),
1543                                },
1544                                position,
1545                                response_rx,
1546                            },
1547                        ))
1548                        .ok();
1549                    Some(response_tx)
1550                } else {
1551                    None
1552                };
1553
1554                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
1555                    if let Some(debug_response_tx) = debug_response_tx {
1556                        debug_response_tx
1557                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
1558                            .ok();
1559                    }
1560                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
1561                }
1562
1563                let (prompt, _) = prompt_result?;
1564                let request = open_ai::Request {
1565                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1566                    messages: vec![open_ai::RequestMessage::User {
1567                        content: open_ai::MessageContent::Plain(prompt),
1568                    }],
1569                    stream: false,
1570                    max_completion_tokens: None,
1571                    stop: Default::default(),
1572                    temperature: 0.7,
1573                    tool_choice: None,
1574                    parallel_tool_calls: None,
1575                    tools: vec![],
1576                    prompt_cache_key: None,
1577                    reasoning_effort: None,
1578                };
1579
1580                log::trace!("Sending edit prediction request");
1581
1582                let before_request = chrono::Utc::now();
1583                let response = Self::send_raw_llm_request(
1584                    request,
1585                    client,
1586                    llm_token,
1587                    app_version,
1588                    #[cfg(feature = "eval-support")]
1589                    eval_cache,
1590                    #[cfg(feature = "eval-support")]
1591                    EvalCacheEntryKind::Prediction,
1592                )
1593                .await;
1594                let request_time = chrono::Utc::now() - before_request;
1595
1596                log::trace!("Got edit prediction response");
1597
1598                if let Some(debug_response_tx) = debug_response_tx {
1599                    debug_response_tx
1600                        .send((
1601                            response
1602                                .as_ref()
1603                                .map_err(|err| err.to_string())
1604                                .map(|response| response.0.clone()),
1605                            request_time,
1606                        ))
1607                        .ok();
1608                }
1609
1610                let (res, usage) = response?;
1611                let request_id = EditPredictionId(res.id.clone().into());
1612                let Some(mut output_text) = text_from_response(res) else {
1613                    return Ok((None, usage));
1614                };
1615
1616                if output_text.contains(CURSOR_MARKER) {
1617                    log::trace!("Stripping out {CURSOR_MARKER} from response");
1618                    output_text = output_text.replace(CURSOR_MARKER, "");
1619                }
1620
1621                let get_buffer_from_context = |path: &Path| {
1622                    included_files
1623                        .iter()
1624                        .find_map(|(_, buffer, probe_path, ranges)| {
1625                            if probe_path.as_ref() == path {
1626                                Some((buffer, ranges.as_slice()))
1627                            } else {
1628                                None
1629                            }
1630                        })
1631                };
1632
1633                let (edited_buffer_snapshot, edits) = match options.prompt_format {
1634                    PromptFormat::NumLinesUniDiff => {
1635                        // TODO: Implement parsing of multi-file diffs
1636                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1637                    }
1638                    PromptFormat::Minimal | PromptFormat::MinimalQwen => {
1639                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
1640                            let edits = vec![];
1641                            (&active_snapshot, edits)
1642                        } else {
1643                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1644                        }
1645                    }
1646                    PromptFormat::OldTextNewText => {
1647                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
1648                            .await?
1649                    }
1650                    _ => {
1651                        bail!("unsupported prompt format {}", options.prompt_format)
1652                    }
1653                };
1654
1655                let edited_buffer = included_files
1656                    .iter()
1657                    .find_map(|(buffer, snapshot, _, _)| {
1658                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1659                            Some(buffer.clone())
1660                        } else {
1661                            None
1662                        }
1663                    })
1664                    .context("Failed to find buffer in included_buffers")?;
1665
1666                anyhow::Ok((
1667                    Some((
1668                        request_id,
1669                        edited_buffer,
1670                        edited_buffer_snapshot.clone(),
1671                        edits,
1672                    )),
1673                    usage,
1674                ))
1675            }
1676        });
1677
1678        cx.spawn({
1679            async move |this, cx| {
1680                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
1681                    Self::handle_api_response(&this, request_task.await, cx)?
1682                else {
1683                    return Ok(None);
1684                };
1685
1686                // TODO telemetry: duration, etc
1687                Ok(
1688                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
1689                        .await,
1690                )
1691            }
1692        })
1693    }
1694
1695    async fn send_raw_llm_request(
1696        request: open_ai::Request,
1697        client: Arc<Client>,
1698        llm_token: LlmApiToken,
1699        app_version: SemanticVersion,
1700        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1701        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
1702    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1703        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1704            http_client::Url::parse(&predict_edits_url)?
1705        } else {
1706            client
1707                .http_client()
1708                .build_zed_llm_url("/predict_edits/raw", &[])?
1709        };
1710
1711        #[cfg(feature = "eval-support")]
1712        let cache_key = if let Some(cache) = eval_cache {
1713            use collections::FxHasher;
1714            use std::hash::{Hash, Hasher};
1715
1716            let mut hasher = FxHasher::default();
1717            url.hash(&mut hasher);
1718            let request_str = serde_json::to_string_pretty(&request)?;
1719            request_str.hash(&mut hasher);
1720            let hash = hasher.finish();
1721
1722            let key = (eval_cache_kind, hash);
1723            if let Some(response_str) = cache.read(key) {
1724                return Ok((serde_json::from_str(&response_str)?, None));
1725            }
1726
1727            Some((cache, request_str, key))
1728        } else {
1729            None
1730        };
1731
1732        let (response, usage) = Self::send_api_request(
1733            |builder| {
1734                let req = builder
1735                    .uri(url.as_ref())
1736                    .body(serde_json::to_string(&request)?.into());
1737                Ok(req?)
1738            },
1739            client,
1740            llm_token,
1741            app_version,
1742        )
1743        .await?;
1744
1745        #[cfg(feature = "eval-support")]
1746        if let Some((cache, request, key)) = cache_key {
1747            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1748        }
1749
1750        Ok((response, usage))
1751    }
1752
1753    fn handle_api_response<T>(
1754        this: &WeakEntity<Self>,
1755        response: Result<(T, Option<EditPredictionUsage>)>,
1756        cx: &mut gpui::AsyncApp,
1757    ) -> Result<T> {
1758        match response {
1759            Ok((data, usage)) => {
1760                if let Some(usage) = usage {
1761                    this.update(cx, |this, cx| {
1762                        this.user_store.update(cx, |user_store, cx| {
1763                            user_store.update_edit_prediction_usage(usage, cx);
1764                        });
1765                    })
1766                    .ok();
1767                }
1768                Ok(data)
1769            }
1770            Err(err) => {
1771                if err.is::<ZedUpdateRequiredError>() {
1772                    cx.update(|cx| {
1773                        this.update(cx, |this, _cx| {
1774                            this.update_required = true;
1775                        })
1776                        .ok();
1777
1778                        let error_message: SharedString = err.to_string().into();
1779                        show_app_notification(
1780                            NotificationId::unique::<ZedUpdateRequiredError>(),
1781                            cx,
1782                            move |cx| {
1783                                cx.new(|cx| {
1784                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1785                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1786                                })
1787                            },
1788                        );
1789                    })
1790                    .ok();
1791                }
1792                Err(err)
1793            }
1794        }
1795    }
1796
    /// Sends an authenticated POST to the LLM service and deserializes the JSON
    /// response body into `Res`, together with any usage info parsed from the
    /// response headers.
    ///
    /// `build` receives a request builder pre-populated with content-type,
    /// bearer-token, and Zed-version headers, and must attach the URI and body.
    ///
    /// If the server reports a minimum required version newer than
    /// `app_version`, fails with [`ZedUpdateRequiredError`]. If the server
    /// signals an expired LLM token, the token is refreshed and the request
    /// retried exactly once; any other non-success status fails with the
    /// response body included in the error.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // Loops at most twice: the second iteration only happens after a
        // one-time token refresh below.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Version gate is checked before the status code so that an
            // outdated client gets the "update required" error even when the
            // request otherwise failed.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                // Usage headers are optional; absence is not an error.
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh once and retry the whole request.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1860
    /// Gap since the last edit after which the user counts as idle; editing
    /// again after this gap triggers an immediate context refresh.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied before refreshing context when the user was not idle.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1863
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing. Only applies in agentic
    // context mode; otherwise this is a no-op.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Idle if we've never refreshed, or the last refresh request was long
        // enough ago.
        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Overwriting the previous debounce task drops it (presumably
        // cancelling it — gpui tasks; TODO confirm), which is what debounces
        // rapid consecutive edits.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    // Coming back from idle: refresh immediately.
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    // Mid-editing: wait for a pause before refreshing.
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Store the refresh task so it runs to completion.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
1911
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    //
    // Pipeline: build a retrieval prompt from the excerpt around the cursor and
    // recent edit events, ask the context-retrieval model to plan searches via
    // a tool call, execute those searches, then store the resulting excerpts on
    // the project state. Each stage reports to `debug_tx` when present.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Untracked project: nothing to refresh.
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        // Only the agentic context mode performs retrieval.
        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        // No excerpt could be selected around the cursor (e.g. size
        // constraints); treat as success with nothing to do.
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Schema generation is relatively involved, so it is computed lazily
        // once and cached for all subsequent refreshes. The `unwrap` relies on
        // the tool schema always carrying a description.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        // Single-shot (non-streaming) request exposing exactly one tool: the
        // search-planning tool the model is expected to call.
        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect queries from every call to our search tool; calls to
            // unknown tools are logged and skipped, malformed arguments fail
            // the refresh.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Back on the entity: clear the in-flight task handle, notify
            // debug listeners, and store the new context (or propagate the
            // search error).
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
2127
2128    pub fn set_context(
2129        &mut self,
2130        project: Entity<Project>,
2131        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2132    ) {
2133        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2134            zeta_project.context = Some(context);
2135        }
2136    }
2137
2138    fn gather_nearby_diagnostics(
2139        cursor_offset: usize,
2140        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
2141        snapshot: &BufferSnapshot,
2142        max_diagnostics_bytes: usize,
2143    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
2144        // TODO: Could make this more efficient
2145        let mut diagnostic_groups = Vec::new();
2146        for (language_server_id, diagnostics) in diagnostic_sets {
2147            let mut groups = Vec::new();
2148            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
2149            diagnostic_groups.extend(
2150                groups
2151                    .into_iter()
2152                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
2153            );
2154        }
2155
2156        // sort by proximity to cursor
2157        diagnostic_groups.sort_by_key(|group| {
2158            let range = &group.entries[group.primary_ix].range;
2159            if range.start >= cursor_offset {
2160                range.start - cursor_offset
2161            } else if cursor_offset >= range.end {
2162                cursor_offset - range.end
2163            } else {
2164                (cursor_offset - range.start).min(range.end - cursor_offset)
2165            }
2166        });
2167
2168        let mut results = Vec::new();
2169        let mut diagnostic_groups_truncated = false;
2170        let mut diagnostics_byte_count = 0;
2171        for group in diagnostic_groups {
2172            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
2173            diagnostics_byte_count += raw_value.get().len();
2174            if diagnostics_byte_count > max_diagnostics_bytes {
2175                diagnostic_groups_truncated = true;
2176                break;
2177            }
2178            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
2179        }
2180
2181        (results, diagnostic_groups_truncated)
2182    }
2183
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds a `PredictEditsRequest` for the zeta CLI by gathering syntax
    /// context around `position` on a background thread.
    ///
    /// Only `ContextMode::Syntax` is supported; agentic mode panics (see TODO
    /// below). Fails if the buffer has no file path or no excerpt can be
    /// selected.
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        // Snapshot the syntax index handle (if any) so it can be locked on the
        // background thread.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, when it has one.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
2261
2262    pub fn wait_for_initial_indexing(
2263        &mut self,
2264        project: &Entity<Project>,
2265        cx: &mut Context<Self>,
2266    ) -> Task<Result<()>> {
2267        let zeta_project = self.get_or_init_zeta_project(project, cx);
2268        if let Some(syntax_index) = &zeta_project.syntax_index {
2269            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2270        } else {
2271            Task::ready(Ok(()))
2272        }
2273    }
2274}
2275
2276pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2277    let choice = res.choices.pop()?;
2278    let output_text = match choice.message {
2279        open_ai::RequestMessage::Assistant {
2280            content: Some(open_ai::MessageContent::Plain(content)),
2281            ..
2282        } => content,
2283        open_ai::RequestMessage::Assistant {
2284            content: Some(open_ai::MessageContent::Multipart(mut content)),
2285            ..
2286        } => {
2287            if content.is_empty() {
2288                log::error!("No output from Baseten completion response");
2289                return None;
2290            }
2291
2292            match content.remove(0) {
2293                open_ai::MessagePart::Text { text } => text,
2294                open_ai::MessagePart::Image { .. } => {
2295                    log::error!("Expected text, got an image");
2296                    return None;
2297                }
2298            }
2299        }
2300        _ => {
2301            log::error!("Invalid response message: {:?}", choice.message);
2302            return None;
2303        }
2304    };
2305    Some(output_text)
2306}
2307
/// Error raised when the server's minimum-required-version header exceeds the
/// running app version; callers detect it via `err.is::<ZedUpdateRequiredError>()`
/// to show an update notification.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: SemanticVersion,
}
2315
2316fn make_syntax_context_cloud_request(
2317    excerpt_path: Arc<Path>,
2318    context: EditPredictionContext,
2319    events: Vec<predict_edits_v3::Event>,
2320    can_collect_data: bool,
2321    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2322    diagnostic_groups_truncated: bool,
2323    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2324    debug_info: bool,
2325    worktrees: &Vec<worktree::Snapshot>,
2326    index_state: Option<&SyntaxIndexState>,
2327    prompt_max_bytes: Option<usize>,
2328    prompt_format: PromptFormat,
2329) -> predict_edits_v3::PredictEditsRequest {
2330    let mut signatures = Vec::new();
2331    let mut declaration_to_signature_index = HashMap::default();
2332    let mut referenced_declarations = Vec::new();
2333
2334    for snippet in context.declarations {
2335        let project_entry_id = snippet.declaration.project_entry_id();
2336        let Some(path) = worktrees.iter().find_map(|worktree| {
2337            worktree.entry_for_id(project_entry_id).map(|entry| {
2338                let mut full_path = RelPathBuf::new();
2339                full_path.push(worktree.root_name());
2340                full_path.push(&entry.path);
2341                full_path
2342            })
2343        }) else {
2344            continue;
2345        };
2346
2347        let parent_index = index_state.and_then(|index_state| {
2348            snippet.declaration.parent().and_then(|parent| {
2349                add_signature(
2350                    parent,
2351                    &mut declaration_to_signature_index,
2352                    &mut signatures,
2353                    index_state,
2354                )
2355            })
2356        });
2357
2358        let (text, text_is_truncated) = snippet.declaration.item_text();
2359        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2360            path: path.as_std_path().into(),
2361            text: text.into(),
2362            range: snippet.declaration.item_line_range(),
2363            text_is_truncated,
2364            signature_range: snippet.declaration.signature_range_in_item_text(),
2365            parent_index,
2366            signature_score: snippet.score(DeclarationStyle::Signature),
2367            declaration_score: snippet.score(DeclarationStyle::Declaration),
2368            score_components: snippet.components,
2369        });
2370    }
2371
2372    let excerpt_parent = index_state.and_then(|index_state| {
2373        context
2374            .excerpt
2375            .parent_declarations
2376            .last()
2377            .and_then(|(parent, _)| {
2378                add_signature(
2379                    *parent,
2380                    &mut declaration_to_signature_index,
2381                    &mut signatures,
2382                    index_state,
2383                )
2384            })
2385    });
2386
2387    predict_edits_v3::PredictEditsRequest {
2388        excerpt_path,
2389        excerpt: context.excerpt_text.body,
2390        excerpt_line_range: context.excerpt.line_range,
2391        excerpt_range: context.excerpt.range,
2392        cursor_point: predict_edits_v3::Point {
2393            line: predict_edits_v3::Line(context.cursor_point.row),
2394            column: context.cursor_point.column,
2395        },
2396        referenced_declarations,
2397        included_files: vec![],
2398        signatures,
2399        excerpt_parent,
2400        events,
2401        can_collect_data,
2402        diagnostic_groups,
2403        diagnostic_groups_truncated,
2404        git_info,
2405        debug_info,
2406        prompt_max_bytes,
2407        prompt_format,
2408    }
2409}
2410
2411fn add_signature(
2412    declaration_id: DeclarationId,
2413    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2414    signatures: &mut Vec<Signature>,
2415    index: &SyntaxIndexState,
2416) -> Option<usize> {
2417    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2418        return Some(*signature_index);
2419    }
2420    let Some(parent_declaration) = index.declaration(declaration_id) else {
2421        log::error!("bug: missing parent declaration");
2422        return None;
2423    };
2424    let parent_index = parent_declaration.parent().and_then(|parent| {
2425        add_signature(parent, declaration_to_signature_index, signatures, index)
2426    });
2427    let (text, text_is_truncated) = parent_declaration.signature_text();
2428    let signature_index = signatures.len();
2429    signatures.push(Signature {
2430        text: text.into(),
2431        text_is_truncated,
2432        parent_index,
2433        range: parent_declaration.signature_line_range(),
2434    });
2435    declaration_to_signature_index.insert(declaration_id, signature_index);
2436    Some(signature_index)
2437}
2438
/// Key for an eval-cache entry: the pipeline stage plus an opaque `u64`
/// discriminator (presumably a hash of the request input — confirm with
/// the cache implementations).
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2441
/// Which stage of the prediction pipeline an eval-cache entry belongs to.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    // Context-gathering stage ("context").
    Context,
    // Retrieval/search stage ("search").
    Search,
    // Final edit-prediction stage ("prediction").
    Prediction,
}
2449
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the lowercase stage name for this cache-entry kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(label)
    }
}
2460
/// Storage backend used during evals to memoize expensive requests by key.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached value for `key`, if one was previously written.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Records `value` under `key`. `input` is the raw request text
    // (presumably kept for inspection/debugging — confirm with implementations).
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2466
2467#[cfg(test)]
2468mod tests {
2469    use std::{path::Path, sync::Arc};
2470
2471    use client::UserStore;
2472    use clock::FakeSystemClock;
2473    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2474    use futures::{
2475        AsyncReadExt, StreamExt,
2476        channel::{mpsc, oneshot},
2477    };
2478    use gpui::{
2479        Entity, TestAppContext,
2480        http_client::{FakeHttpClient, Response},
2481        prelude::*,
2482    };
2483    use indoc::indoc;
2484    use language::OffsetRangeExt as _;
2485    use open_ai::Usage;
2486    use pretty_assertions::{assert_eq, assert_matches};
2487    use project::{FakeFs, Project};
2488    use serde_json::json;
2489    use settings::SettingsStore;
2490    use util::path;
2491    use uuid::Uuid;
2492
2493    use crate::{BufferEditPrediction, Zeta};
2494
    /// Exercises `current_prediction_for_buffer` through three phases:
    /// a same-file prediction, a context refresh driven by the search tool,
    /// and a cross-file prediction that is reported as `Jump` until the
    /// target buffer is opened (then becomes `Local`).
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Cursor at the end of line 1 ("How").
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // Respond with a diff touching 1.txt — the buffer the cursor is in.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        // An edit in the cursor's own buffer surfaces as a Local prediction.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Reply with an assistant message that invokes the search tool,
        // requesting content from 2.txt.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        // Clear the first prediction before requesting the cross-file one.
        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // From buffer1's perspective the edit lives elsewhere: a Jump pointing
        // at 2.txt.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once 2.txt is open, the same prediction is Local for that buffer.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
2641
    /// Happy-path check of `request_prediction`: a single-hunk diff response
    /// is parsed into exactly one edit anchored at the cursor position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Cursor at the end of line 1 ("How").
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // "How" -> "How are you?" is a pure insertion of " are you?" at (1, 3).
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
2705
    /// Verifies that edits made to a registered buffer are reported to the
    /// model as a unified diff embedded in the request prompt.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Registration is what makes Zeta track this buffer's edit history.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Insert "How" at offset 7 — the empty second line of the file.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The prompt must contain the diff of the edit performed above.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The model's diff becomes a single insertion of " are you?" at (1, 3).
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
2779
2780    // Skipped until we start including diagnostics in prompt
2781    // #[gpui::test]
2782    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
2783    //     let (zeta, mut req_rx) = init_test(cx);
2784    //     let fs = FakeFs::new(cx.executor());
2785    //     fs.insert_tree(
2786    //         "/root",
2787    //         json!({
2788    //             "foo.md": "Hello!\nBye"
2789    //         }),
2790    //     )
2791    //     .await;
2792    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2793
2794    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
2795    //     let diagnostic = lsp::Diagnostic {
2796    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
2797    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
2798    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
2799    //         ..Default::default()
2800    //     };
2801
2802    //     project.update(cx, |project, cx| {
2803    //         project.lsp_store().update(cx, |lsp_store, cx| {
2804    //             // Create some diagnostics
2805    //             lsp_store
2806    //                 .update_diagnostics(
2807    //                     LanguageServerId(0),
2808    //                     lsp::PublishDiagnosticsParams {
2809    //                         uri: path_to_buffer_uri.clone(),
2810    //                         diagnostics: vec![diagnostic],
2811    //                         version: None,
2812    //                     },
2813    //                     None,
2814    //                     language::DiagnosticSourceKind::Pushed,
2815    //                     &[],
2816    //                     cx,
2817    //                 )
2818    //                 .unwrap();
2819    //         });
2820    //     });
2821
2822    //     let buffer = project
2823    //         .update(cx, |project, cx| {
2824    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2825    //             project.open_buffer(path, cx)
2826    //         })
2827    //         .await
2828    //         .unwrap();
2829
2830    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2831    //     let position = snapshot.anchor_before(language::Point::new(0, 0));
2832
2833    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
2834    //         zeta.request_prediction(&project, &buffer, position, cx)
2835    //     });
2836
2837    //     let (request, _respond_tx) = req_rx.next().await.unwrap();
2838
2839    //     assert_eq!(request.diagnostic_groups.len(), 1);
2840    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
2841    //         .unwrap();
2842    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
2843    //     assert_eq!(
2844    //         value,
2845    //         json!({
2846    //             "entries": [{
2847    //                 "range": {
2848    //                     "start": 8,
2849    //                     "end": 10
2850    //                 },
2851    //                 "diagnostic": {
2852    //                     "source": null,
2853    //                     "code": null,
2854    //                     "code_description": null,
2855    //                     "severity": 1,
2856    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
2857    //                     "markdown": null,
2858    //                     "group_id": 0,
2859    //                     "is_primary": true,
2860    //                     "is_disk_based": false,
2861    //                     "is_unnecessary": false,
2862    //                     "source_kind": "Pushed",
2863    //                     "data": null,
2864    //                     "underline": true
2865    //                 }
2866    //             }],
2867    //             "primary_ix": 0
2868    //         })
2869    //     );
2870    // }
2871
2872    fn model_response(text: &str) -> open_ai::Response {
2873        open_ai::Response {
2874            id: Uuid::new_v4().to_string(),
2875            object: "response".into(),
2876            created: 0,
2877            model: "model".into(),
2878            choices: vec![open_ai::Choice {
2879                index: 0,
2880                message: open_ai::RequestMessage::Assistant {
2881                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
2882                    tool_calls: vec![],
2883                },
2884                finish_reason: None,
2885            }],
2886            usage: Usage {
2887                prompt_tokens: 0,
2888                completion_tokens: 0,
2889                total_tokens: 0,
2890            },
2891        }
2892    }
2893
2894    fn prompt_from_request(request: &open_ai::Request) -> &str {
2895        assert_eq!(request.messages.len(), 1);
2896        let open_ai::RequestMessage::User {
2897            content: open_ai::MessageContent::Plain(content),
2898            ..
2899        } = &request.messages[0]
2900        else {
2901            panic!(
2902                "Request does not have single user message of type Plain. {:#?}",
2903                request
2904            );
2905        };
2906        content
2907    }
2908
    /// Creates a `Zeta` entity wired to a fake HTTP client for tests.
    ///
    /// Returns the entity plus a receiver that yields each prediction request
    /// together with a oneshot sender the test uses to supply the model's
    /// response.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token endpoint: hand back a dummy LLM token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction endpoint: decode the body, forward it
                            // to the test, and block until the test responds.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
2963}