zeta2.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  10use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  11use collections::HashMap;
  12use edit_prediction_context::{
  13    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  14    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  15    SyntaxIndex, SyntaxIndexState,
  16};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::AsyncReadExt as _;
  19use futures::channel::{mpsc, oneshot};
  20use gpui::http_client::{AsyncBody, Method};
  21use gpui::{
  22    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  23    http_client, prelude::*,
  24};
  25use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  26use language::{BufferSnapshot, OffsetRangeExt};
  27use language_model::{LlmApiToken, RefreshLlmTokenListener};
  28use open_ai::FunctionDefinition;
  29use project::Project;
  30use release_channel::AppVersion;
  31use serde::de::DeserializeOwned;
  32use std::collections::{VecDeque, hash_map};
  33
  34use std::env;
  35use std::ops::Range;
  36use std::path::Path;
  37use std::str::FromStr as _;
  38use std::sync::{Arc, LazyLock};
  39use std::time::{Duration, Instant};
  40use thiserror::Error;
  41use util::rel_path::RelPathBuf;
  42use util::{LogErrorFuture, TryFutureExt};
  43use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  44
  45pub mod merge_excerpts;
  46mod prediction;
  47mod provider;
  48pub mod retrieval_search;
  49pub mod udiff;
  50mod xml_edits;
  51
  52use crate::merge_excerpts::merge_excerpts;
  53use crate::prediction::EditPrediction;
  54pub use crate::prediction::EditPredictionId;
  55pub use provider::ZetaEditPredictionProvider;
  56
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Default sizing of the excerpt captured around the cursor position.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim to place roughly half of the excerpt's bytes before the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context-gathering mode: agentic (search-driven) retrieval.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Default options for [`ContextMode::Agentic`].
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Default options for [`ContextMode::Syntax`] (syntax-index-based retrieval).
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default top-level [`ZetaOptions`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};

// Set ZED_ZETA2_OLLAMA to any non-empty value to route predictions through a
// local Ollama server instead of the hosted backend.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
// Model identifier used for prediction requests; overridable via
// ZED_ZETA2_MODEL. The default differs between the Ollama and hosted backends.
static MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
// Prediction endpoint override: ZED_PREDICT_EDITS_URL takes precedence, then
// the local Ollama endpoint when ZED_ZETA2_OLLAMA is set; `None` means the
// default endpoint is used.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
 110
/// Feature flag gating the zeta2 edit-prediction provider.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Staff do not get this flag implicitly; it must be enabled explicitly.
    fn enabled_for_staff() -> bool {
        false
    }
}
 120
/// App-global handle to the single [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 125
/// Coordinates edit-prediction state across projects: per-project event
/// history, retrieved context, and the current prediction, plus the client
/// and token used to talk to the prediction service.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // LLM API token; refreshed via `_llm_token_subscription` below.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server reports this client version
    // is too old (see MINIMUM_REQUIRED_VERSION_HEADER_NAME) — confirm where
    // it is written in the rest of this file.
    update_required: bool,
    // When present, debug events are streamed to this channel.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
}
 138
/// Tunable options controlling context gathering and prompt construction.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    // Buffer edits closer together than this are coalesced into one event.
    pub buffer_change_grouping_interval: Duration,
}

/// How context for a prediction request is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Search-driven (model-directed) retrieval.
    Agentic(AgenticContextOptions),
    /// Syntax-index-based retrieval.
    Syntax(EditPredictionContextOptions),
}

/// Options specific to [`ContextMode::Agentic`].
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
 159
 160impl ContextMode {
 161    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 162        match self {
 163            ContextMode::Agentic(options) => &options.excerpt,
 164            ContextMode::Syntax(options) => &options.excerpt,
 165        }
 166    }
 167}
 168
/// Debug events emitted over the channel returned by [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

/// Payload for [`ZetaDebugInfo::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// Prompt used for the search phase of retrieval.
    pub search_prompt: String,
}

/// Shared payload for retrieval progress/completion events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Payload for [`ZetaDebugInfo::EditPredictionRequested`].
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Elapsed time of the context-retrieval phase.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// Locally rendered prompt, or the rendering failure message.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error message) and the elapsed
    /// time of the request.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

/// Payload for [`ZetaDebugInfo::SearchQueriesGenerated`].
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 209
/// Per-project state tracked by [`Zeta`].
struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer-change events, oldest first; bounded by MAX_EVENT_COUNT.
    events: VecDeque<Event>,
    /// Buffers being watched for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// Retrieved context: anchor ranges of interest per buffer. `None` until
    /// a retrieval has completed.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}

/// The most recent prediction for a project, along with which buffer
/// requested it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
 226
 227impl CurrentEditPrediction {
 228    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 229        let Some(new_edits) = self
 230            .prediction
 231            .interpolate(&self.prediction.buffer.read(cx))
 232        else {
 233            return false;
 234        };
 235
 236        if self.prediction.buffer != old_prediction.prediction.buffer {
 237            return true;
 238        }
 239
 240        let Some(old_edits) = old_prediction
 241            .prediction
 242            .interpolate(&old_prediction.prediction.buffer.read(cx))
 243        else {
 244            return true;
 245        };
 246
 247        // This reduces the occurrence of UI thrash from replacing edits
 248        //
 249        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 250        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 251            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 252            && old_edits.len() == 1
 253            && new_edits.len() == 1
 254        {
 255            let (old_range, old_text) = &old_edits[0];
 256            let (new_range, new_text) = &new_edits[0];
 257            new_range == old_range && new_text.starts_with(old_text.as_ref())
 258        } else {
 259            true
 260        }
 261    }
 262}
 263
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction was requested from this buffer but targets another one.
    Jump { prediction: &'a EditPrediction },
}

/// Book-keeping for a buffer that Zeta watches for edits.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; diffed against the next
    /// version in `report_changes_for_buffer`.
    snapshot: BufferSnapshot,
    /// Keeps the edit subscription and the release observer alive.
    _subscriptions: [gpui::Subscription; 2],
}

/// A recorded user action sent as history with prediction requests.
#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
 284
 285impl Event {
 286    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
 287        match self {
 288            Event::BufferChange {
 289                old_snapshot,
 290                new_snapshot,
 291                ..
 292            } => {
 293                let path = new_snapshot.file().map(|f| f.full_path(cx));
 294
 295                let old_path = old_snapshot.file().and_then(|f| {
 296                    let old_path = f.full_path(cx);
 297                    if Some(&old_path) != path.as_ref() {
 298                        Some(old_path)
 299                    } else {
 300                        None
 301                    }
 302                });
 303
 304                // TODO [zeta2] move to bg?
 305                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 306
 307                if path == old_path && diff.is_empty() {
 308                    None
 309                } else {
 310                    Some(predict_edits_v3::Event::BufferChange {
 311                        old_path,
 312                        path,
 313                        diff,
 314                        //todo: Actually detect if this edit was predicted or not
 315                        predicted: false,
 316                    })
 317                }
 318            }
 319        }
 320    }
 321}
 322
 323impl Zeta {
 324    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 325        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 326    }
 327
 328    pub fn global(
 329        client: &Arc<Client>,
 330        user_store: &Entity<UserStore>,
 331        cx: &mut App,
 332    ) -> Entity<Self> {
 333        cx.try_global::<ZetaGlobal>()
 334            .map(|global| global.0.clone())
 335            .unwrap_or_else(|| {
 336                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 337                cx.set_global(ZetaGlobal(zeta.clone()));
 338                zeta
 339            })
 340    }
 341
 342    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 343        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 344
 345        Self {
 346            projects: HashMap::default(),
 347            client,
 348            user_store,
 349            options: DEFAULT_OPTIONS,
 350            llm_token: LlmApiToken::default(),
 351            _llm_token_subscription: cx.subscribe(
 352                &refresh_llm_token_listener,
 353                |this, _listener, _event, cx| {
 354                    let client = this.client.clone();
 355                    let llm_token = this.llm_token.clone();
 356                    cx.spawn(async move |_this, _cx| {
 357                        llm_token.refresh(&client).await?;
 358                        anyhow::Ok(())
 359                    })
 360                    .detach_and_log_err(cx);
 361                },
 362            ),
 363            update_required: false,
 364            debug_tx: None,
 365            #[cfg(feature = "eval-support")]
 366            eval_cache: None,
 367        }
 368    }
 369
    /// Installs a cache used during evals (eval-support builds only).
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }

    /// Starts streaming [`ZetaDebugInfo`] events, returning the receiving end
    /// of the channel. Replaces any previously installed receiver.
    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    /// The current prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prediction options.
    ///
    /// NOTE(review): existing projects' `SyntaxIndex` was created with the
    /// previous `file_indexing_parallelism`; confirm whether changing it here
    /// is expected to affect already-registered projects.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 388
 389    pub fn clear_history(&mut self) {
 390        for zeta_project in self.projects.values_mut() {
 391            zeta_project.events.clear();
 392        }
 393    }
 394
 395    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
 396        self.projects
 397            .get(&project.entity_id())
 398            .map(|project| project.events.iter())
 399            .into_iter()
 400            .flatten()
 401    }
 402
 403    pub fn context_for_project(
 404        &self,
 405        project: &Entity<Project>,
 406    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 407        self.projects
 408            .get(&project.entity_id())
 409            .and_then(|project| {
 410                Some(
 411                    project
 412                        .context
 413                        .as_ref()?
 414                        .iter()
 415                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 416                )
 417            })
 418            .into_iter()
 419            .flatten()
 420    }
 421
    /// Returns the user's edit-prediction usage from the user store, if known.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }

    /// Ensures per-project state (including its syntax index) exists for
    /// `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking `buffer` within `project`, subscribing to its edits so
    /// they are recorded in the project's event history.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 439
    /// Returns the [`ZetaProject`] for `project`, creating it (and spawning
    /// its [`SyntaxIndex`]) on first use.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: cx.new(|cx| {
                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                }),
                events: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
            })
    }
 460
    /// Returns the [`RegisteredBuffer`] for `buffer`, registering it on first
    /// use.
    ///
    /// Registration snapshots the buffer and installs two subscriptions: one
    /// reports edits into the project's event history, the other removes the
    /// registration when the buffer entity is released.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Hold only a weak handle so this subscription
                            // doesn't keep the project alive.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Clean up the registration when the buffer is dropped.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 498
    /// Records any change in `buffer` since the last report as an
    /// [`Event::BufferChange`] in the project's history, and returns the
    /// buffer's latest snapshot.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Only push an event when the buffer version actually advanced since
        // the snapshot we last recorded.
        if new_snapshot.version != registered_buffer.snapshot.version {
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                buffer_change_grouping_interval,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }
 526
    /// Appends `event` to the project's history, coalescing it into the
    /// previous event when both edit the same buffer in quick succession, and
    /// bounding the history at MAX_EVENT_COUNT entries.
    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            // The new event continues the previous one if it arrived within
            // the grouping interval and starts from exactly the buffer
            // version the previous event ended at.
            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }
 565
 566    fn current_prediction_for_buffer(
 567        &self,
 568        buffer: &Entity<Buffer>,
 569        project: &Entity<Project>,
 570        cx: &App,
 571    ) -> Option<BufferEditPrediction<'_>> {
 572        let project_state = self.projects.get(&project.entity_id())?;
 573
 574        let CurrentEditPrediction {
 575            requested_by_buffer_id,
 576            prediction,
 577        } = project_state.current_prediction.as_ref()?;
 578
 579        if prediction.targets_buffer(buffer.read(cx)) {
 580            Some(BufferEditPrediction::Local { prediction })
 581        } else if *requested_by_buffer_id == buffer.entity_id() {
 582            Some(BufferEditPrediction::Jump { prediction })
 583        } else {
 584            None
 585        }
 586    }
 587
    /// Accepts the project's current prediction: removes it locally and
    /// reports the acceptance to the server's `/predict_edits/accept`
    /// endpoint (overridable via the ZED_ACCEPT_PREDICTION_URL environment
    /// variable). No-op when there is no current prediction.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            // The request body carries only the id of the accepted prediction.
            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // NOTE(review): `handle_api_response` is defined elsewhere in this
            // file; presumably it surfaces API errors (e.g. update-required)
            // to the user — confirm against its definition.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 632
 633    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 634        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 635            project_state.current_prediction.take();
 636        };
 637    }
 638
 639    pub fn refresh_prediction(
 640        &mut self,
 641        project: &Entity<Project>,
 642        buffer: &Entity<Buffer>,
 643        position: language::Anchor,
 644        cx: &mut Context<Self>,
 645    ) -> Task<Result<()>> {
 646        let request_task = self.request_prediction(project, buffer, position, cx);
 647        let buffer = buffer.clone();
 648        let project = project.clone();
 649
 650        cx.spawn(async move |this, cx| {
 651            if let Some(prediction) = request_task.await? {
 652                this.update(cx, |this, cx| {
 653                    let project_state = this
 654                        .projects
 655                        .get_mut(&project.entity_id())
 656                        .context("Project not found")?;
 657
 658                    let new_prediction = CurrentEditPrediction {
 659                        requested_by_buffer_id: buffer.entity_id(),
 660                        prediction: prediction,
 661                    };
 662
 663                    if project_state
 664                        .current_prediction
 665                        .as_ref()
 666                        .is_none_or(|old_prediction| {
 667                            new_prediction.should_replace_prediction(&old_prediction, cx)
 668                        })
 669                    {
 670                        project_state.current_prediction = Some(new_prediction);
 671                    }
 672                    anyhow::Ok(())
 673                })??;
 674            }
 675            Ok(())
 676        })
 677    }
 678
 679    pub fn request_prediction(
 680        &mut self,
 681        project: &Entity<Project>,
 682        active_buffer: &Entity<Buffer>,
 683        position: language::Anchor,
 684        cx: &mut Context<Self>,
 685    ) -> Task<Result<Option<EditPrediction>>> {
 686        let project_state = self.projects.get(&project.entity_id());
 687
 688        let index_state = project_state.map(|state| {
 689            state
 690                .syntax_index
 691                .read_with(cx, |index, _cx| index.state().clone())
 692        });
 693        let options = self.options.clone();
 694        let active_snapshot = active_buffer.read(cx).snapshot();
 695        let Some(excerpt_path) = active_snapshot
 696            .file()
 697            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
 698        else {
 699            return Task::ready(Err(anyhow!("No file path for excerpt")));
 700        };
 701        let client = self.client.clone();
 702        let llm_token = self.llm_token.clone();
 703        let app_version = AppVersion::global(cx);
 704        let worktree_snapshots = project
 705            .read(cx)
 706            .worktrees(cx)
 707            .map(|worktree| worktree.read(cx).snapshot())
 708            .collect::<Vec<_>>();
 709        let debug_tx = self.debug_tx.clone();
 710
 711        let events = project_state
 712            .map(|state| {
 713                state
 714                    .events
 715                    .iter()
 716                    .filter_map(|event| event.to_request_event(cx))
 717                    .collect::<Vec<_>>()
 718            })
 719            .unwrap_or_default();
 720
 721        let diagnostics = active_snapshot.diagnostic_sets().clone();
 722
 723        let parent_abs_path =
 724            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
 725                let mut path = f.worktree.read(cx).absolutize(&f.path);
 726                if path.pop() { Some(path) } else { None }
 727            });
 728
 729        // TODO data collection
 730        let can_collect_data = cx.is_staff();
 731
 732        let empty_context_files = HashMap::default();
 733        let context_files = project_state
 734            .and_then(|project_state| project_state.context.as_ref())
 735            .unwrap_or(&empty_context_files);
 736
 737        #[cfg(feature = "eval-support")]
 738        let parsed_fut = futures::future::join_all(
 739            context_files
 740                .keys()
 741                .map(|buffer| buffer.read(cx).parsing_idle()),
 742        );
 743
 744        let mut included_files = context_files
 745            .iter()
 746            .filter_map(|(buffer_entity, ranges)| {
 747                let buffer = buffer_entity.read(cx);
 748                Some((
 749                    buffer_entity.clone(),
 750                    buffer.snapshot(),
 751                    buffer.file()?.full_path(cx).into(),
 752                    ranges.clone(),
 753                ))
 754            })
 755            .collect::<Vec<_>>();
 756
 757        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
 758            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
 759        });
 760
 761        #[cfg(feature = "eval-support")]
 762        let eval_cache = self.eval_cache.clone();
 763
 764        let request_task = cx.background_spawn({
 765            let active_buffer = active_buffer.clone();
 766            async move {
 767                #[cfg(feature = "eval-support")]
 768                parsed_fut.await;
 769
 770                let index_state = if let Some(index_state) = index_state {
 771                    Some(index_state.lock_owned().await)
 772                } else {
 773                    None
 774                };
 775
 776                let cursor_offset = position.to_offset(&active_snapshot);
 777                let cursor_point = cursor_offset.to_point(&active_snapshot);
 778
 779                let before_retrieval = chrono::Utc::now();
 780
 781                let (diagnostic_groups, diagnostic_groups_truncated) =
 782                    Self::gather_nearby_diagnostics(
 783                        cursor_offset,
 784                        &diagnostics,
 785                        &active_snapshot,
 786                        options.max_diagnostic_bytes,
 787                    );
 788
 789                let cloud_request = match options.context {
 790                    ContextMode::Agentic(context_options) => {
 791                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
 792                            cursor_point,
 793                            &active_snapshot,
 794                            &context_options.excerpt,
 795                            index_state.as_deref(),
 796                        ) else {
 797                            return Ok((None, None));
 798                        };
 799
 800                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
 801                            ..active_snapshot.anchor_before(excerpt.range.end);
 802
 803                        if let Some(buffer_ix) =
 804                            included_files.iter().position(|(_, snapshot, _, _)| {
 805                                snapshot.remote_id() == active_snapshot.remote_id()
 806                            })
 807                        {
 808                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
 809                            let range_ix = ranges
 810                                .binary_search_by(|probe| {
 811                                    probe
 812                                        .start
 813                                        .cmp(&excerpt_anchor_range.start, buffer)
 814                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
 815                                })
 816                                .unwrap_or_else(|ix| ix);
 817
 818                            ranges.insert(range_ix, excerpt_anchor_range);
 819                            let last_ix = included_files.len() - 1;
 820                            included_files.swap(buffer_ix, last_ix);
 821                        } else {
 822                            included_files.push((
 823                                active_buffer.clone(),
 824                                active_snapshot,
 825                                excerpt_path.clone(),
 826                                vec![excerpt_anchor_range],
 827                            ));
 828                        }
 829
 830                        let included_files = included_files
 831                            .iter()
 832                            .map(|(_, snapshot, path, ranges)| {
 833                                let excerpts = merge_excerpts(
 834                                    &snapshot,
 835                                    ranges.iter().map(|range| {
 836                                        let point_range = range.to_point(&snapshot);
 837                                        Line(point_range.start.row)..Line(point_range.end.row)
 838                                    }),
 839                                );
 840                                predict_edits_v3::IncludedFile {
 841                                    path: path.clone(),
 842                                    max_row: Line(snapshot.max_point().row),
 843                                    excerpts,
 844                                }
 845                            })
 846                            .collect::<Vec<_>>();
 847
 848                        predict_edits_v3::PredictEditsRequest {
 849                            excerpt_path,
 850                            excerpt: String::new(),
 851                            excerpt_line_range: Line(0)..Line(0),
 852                            excerpt_range: 0..0,
 853                            cursor_point: predict_edits_v3::Point {
 854                                line: predict_edits_v3::Line(cursor_point.row),
 855                                column: cursor_point.column,
 856                            },
 857                            included_files,
 858                            referenced_declarations: vec![],
 859                            events,
 860                            can_collect_data,
 861                            diagnostic_groups,
 862                            diagnostic_groups_truncated,
 863                            debug_info: debug_tx.is_some(),
 864                            prompt_max_bytes: Some(options.max_prompt_bytes),
 865                            prompt_format: options.prompt_format,
 866                            // TODO [zeta2]
 867                            signatures: vec![],
 868                            excerpt_parent: None,
 869                            git_info: None,
 870                        }
 871                    }
 872                    ContextMode::Syntax(context_options) => {
 873                        let Some(context) = EditPredictionContext::gather_context(
 874                            cursor_point,
 875                            &active_snapshot,
 876                            parent_abs_path.as_deref(),
 877                            &context_options,
 878                            index_state.as_deref(),
 879                        ) else {
 880                            return Ok((None, None));
 881                        };
 882
 883                        make_syntax_context_cloud_request(
 884                            excerpt_path,
 885                            context,
 886                            events,
 887                            can_collect_data,
 888                            diagnostic_groups,
 889                            diagnostic_groups_truncated,
 890                            None,
 891                            debug_tx.is_some(),
 892                            &worktree_snapshots,
 893                            index_state.as_deref(),
 894                            Some(options.max_prompt_bytes),
 895                            options.prompt_format,
 896                        )
 897                    }
 898                };
 899
 900                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
 901
 902                let retrieval_time = chrono::Utc::now() - before_retrieval;
 903
 904                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
 905                    let (response_tx, response_rx) = oneshot::channel();
 906
 907                    debug_tx
 908                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
 909                            ZetaEditPredictionDebugInfo {
 910                                request: cloud_request.clone(),
 911                                retrieval_time,
 912                                buffer: active_buffer.downgrade(),
 913                                local_prompt: match prompt_result.as_ref() {
 914                                    Ok((prompt, _)) => Ok(prompt.clone()),
 915                                    Err(err) => Err(err.to_string()),
 916                                },
 917                                position,
 918                                response_rx,
 919                            },
 920                        ))
 921                        .ok();
 922                    Some(response_tx)
 923                } else {
 924                    None
 925                };
 926
 927                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 928                    if let Some(debug_response_tx) = debug_response_tx {
 929                        debug_response_tx
 930                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
 931                            .ok();
 932                    }
 933                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 934                }
 935
 936                let (prompt, _) = prompt_result?;
 937                let request = open_ai::Request {
 938                    model: MODEL_ID.clone(),
 939                    messages: vec![open_ai::RequestMessage::User {
 940                        content: open_ai::MessageContent::Plain(prompt),
 941                    }],
 942                    stream: false,
 943                    max_completion_tokens: None,
 944                    stop: Default::default(),
 945                    temperature: 0.7,
 946                    tool_choice: None,
 947                    parallel_tool_calls: None,
 948                    tools: vec![],
 949                    prompt_cache_key: None,
 950                    reasoning_effort: None,
 951                };
 952
 953                log::trace!("Sending edit prediction request");
 954
 955                let before_request = chrono::Utc::now();
 956                let response = Self::send_raw_llm_request(
 957                    request,
 958                    client,
 959                    llm_token,
 960                    app_version,
 961                    #[cfg(feature = "eval-support")]
 962                    eval_cache,
 963                    #[cfg(feature = "eval-support")]
 964                    EvalCacheEntryKind::Prediction,
 965                )
 966                .await;
 967                let request_time = chrono::Utc::now() - before_request;
 968
 969                log::trace!("Got edit prediction response");
 970
 971                if let Some(debug_response_tx) = debug_response_tx {
 972                    debug_response_tx
 973                        .send((
 974                            response
 975                                .as_ref()
 976                                .map_err(|err| err.to_string())
 977                                .map(|response| response.0.clone()),
 978                            request_time,
 979                        ))
 980                        .ok();
 981                }
 982
 983                let (res, usage) = response?;
 984                let request_id = EditPredictionId(res.id.clone().into());
 985                let Some(mut output_text) = text_from_response(res) else {
 986                    return Ok((None, usage));
 987                };
 988
 989                if output_text.contains(CURSOR_MARKER) {
 990                    log::trace!("Stripping out {CURSOR_MARKER} from response");
 991                    output_text = output_text.replace(CURSOR_MARKER, "");
 992                }
 993
 994                let get_buffer_from_context = |path: &Path| {
 995                    included_files
 996                        .iter()
 997                        .find_map(|(_, buffer, probe_path, ranges)| {
 998                            if probe_path.as_ref() == path {
 999                                Some((buffer, ranges.as_slice()))
1000                            } else {
1001                                None
1002                            }
1003                        })
1004                };
1005
1006                let (edited_buffer_snapshot, edits) = match options.prompt_format {
1007                    PromptFormat::NumLinesUniDiff => {
1008                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1009                    }
1010                    PromptFormat::OldTextNewText => {
1011                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
1012                            .await?
1013                    }
1014                    _ => {
1015                        bail!("unsupported prompt format {}", options.prompt_format)
1016                    }
1017                };
1018
1019                let edited_buffer = included_files
1020                    .iter()
1021                    .find_map(|(buffer, snapshot, _, _)| {
1022                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1023                            Some(buffer.clone())
1024                        } else {
1025                            None
1026                        }
1027                    })
1028                    .context("Failed to find buffer in included_buffers")?;
1029
1030                anyhow::Ok((
1031                    Some((
1032                        request_id,
1033                        edited_buffer,
1034                        edited_buffer_snapshot.clone(),
1035                        edits,
1036                    )),
1037                    usage,
1038                ))
1039            }
1040        });
1041
1042        cx.spawn({
1043            async move |this, cx| {
1044                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
1045                    Self::handle_api_response(&this, request_task.await, cx)?
1046                else {
1047                    return Ok(None);
1048                };
1049
1050                // TODO telemetry: duration, etc
1051                Ok(
1052                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
1053                        .await,
1054                )
1055            }
1056        })
1057    }
1058
    /// Sends `request` to the LLM endpoint and deserializes the response.
    ///
    /// The target URL is the `PREDICT_EDITS_URL` override when set, otherwise
    /// the client's `/predict_edits/raw` LLM endpoint. With the `eval-support`
    /// feature enabled, responses are memoized in `eval_cache`, keyed by a
    /// hash of the URL and the serialized request.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        // Prefer the explicit URL override (e.g. for local testing) over the
        // client's default LLM endpoint.
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            // Cache key: hash of (url, pretty-printed request JSON).
            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            // Cache hit: skip the network round-trip entirely. Usage info is
            // not cached, so return `None` for it.
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Cache miss earlier: record the fresh response for future eval runs.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1116
1117    fn handle_api_response<T>(
1118        this: &WeakEntity<Self>,
1119        response: Result<(T, Option<EditPredictionUsage>)>,
1120        cx: &mut gpui::AsyncApp,
1121    ) -> Result<T> {
1122        match response {
1123            Ok((data, usage)) => {
1124                if let Some(usage) = usage {
1125                    this.update(cx, |this, cx| {
1126                        this.user_store.update(cx, |user_store, cx| {
1127                            user_store.update_edit_prediction_usage(usage, cx);
1128                        });
1129                    })
1130                    .ok();
1131                }
1132                Ok(data)
1133            }
1134            Err(err) => {
1135                if err.is::<ZedUpdateRequiredError>() {
1136                    cx.update(|cx| {
1137                        this.update(cx, |this, _cx| {
1138                            this.update_required = true;
1139                        })
1140                        .ok();
1141
1142                        let error_message: SharedString = err.to_string().into();
1143                        show_app_notification(
1144                            NotificationId::unique::<ZedUpdateRequiredError>(),
1145                            cx,
1146                            move |cx| {
1147                                cx.new(|cx| {
1148                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1149                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1150                                })
1151                            },
1152                        );
1153                    })
1154                    .ok();
1155                }
1156                Err(err)
1157            }
1158        }
1159    }
1160
    /// Performs an authenticated POST to the Zed LLM service and deserializes
    /// the JSON response body into `Res`.
    ///
    /// Adds content-type, bearer-token, and Zed-version headers; enforces the
    /// server-advertised minimum required app version; and retries exactly
    /// once with a refreshed token when the server flags the LLM token as
    /// expired. Any other non-success status fails with the response body
    /// included in the error.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        // Retry at most once, and only for an expired-token response.
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            // The caller-provided `build` closure sets the URI and body; the
            // common headers are set here.
            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Fail with `ZedUpdateRequiredError` when the server says this
            // client version is too old. A header that fails to parse as a
            // semantic version is ignored.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                // Usage headers are optional; their absence is not an error.
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh and loop to retry the request once.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1224
    /// Gap since the last context refresh beyond which the next edit is
    /// treated as coming back from idle (triggers an immediate refresh).
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Pause in editing to wait for before refreshing context again when the
    /// user was not idle.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1227
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context retrieval only applies to the agentic context mode.
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // "Idle" means no refresh has been requested within the idle window
        // (or ever, for a fresh project).
        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Replacing the previous debounce task cancels it, so only the most
        // recent request survives rapid successive edits.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    // Coming back from idle: refresh right away.
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    // Otherwise wait for a pause in editing before refreshing.
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Store the task so it runs to completion (see refresh_context).
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
1275
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    /// Plans and executes context-retrieval searches for the buffer around
    /// `cursor_position`, storing the resulting excerpts on the project state.
    ///
    /// Pipeline: build a retrieval prompt from the cursor excerpt and recent
    /// events → ask the model to call the search tool → parse the returned
    /// tool calls into `SearchToolQuery`s → run the searches → store the
    /// excerpts in `zeta_project.context`. Debug events are emitted at each
    /// stage when a debug channel is attached. Returns a ready task (no-op)
    /// when the project is unknown, the mode is not agentic, or no excerpt
    /// can be selected.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        // Select the excerpt around the cursor that the planning prompt will
        // be built from.
        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        // Untitled buffers get a placeholder path.
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the search-planning prompt from the excerpt plus the recent
        // edit events recorded for this project.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // JSON schema + description for the search tool, computed once per
        // process. The `unwrap` relies on the schema carrying a description.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        // Non-streaming request that offers the model exactly one tool: the
        // context-retrieval search tool.
        let request = open_ai::Request {
            model: MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            // The planning response must contain an assistant message with
            // tool calls; extract the search queries from them.
            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                // Calls to unknown tools are logged and skipped rather than
                // failing the whole retrieval.
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Store the results on the project (if it still exists) and clear
            // the in-flight task handle.
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
1491
1492    pub fn set_context(
1493        &mut self,
1494        project: Entity<Project>,
1495        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
1496    ) {
1497        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
1498            zeta_project.context = Some(context);
1499        }
1500    }
1501
1502    fn gather_nearby_diagnostics(
1503        cursor_offset: usize,
1504        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1505        snapshot: &BufferSnapshot,
1506        max_diagnostics_bytes: usize,
1507    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1508        // TODO: Could make this more efficient
1509        let mut diagnostic_groups = Vec::new();
1510        for (language_server_id, diagnostics) in diagnostic_sets {
1511            let mut groups = Vec::new();
1512            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1513            diagnostic_groups.extend(
1514                groups
1515                    .into_iter()
1516                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1517            );
1518        }
1519
1520        // sort by proximity to cursor
1521        diagnostic_groups.sort_by_key(|group| {
1522            let range = &group.entries[group.primary_ix].range;
1523            if range.start >= cursor_offset {
1524                range.start - cursor_offset
1525            } else if cursor_offset >= range.end {
1526                cursor_offset - range.end
1527            } else {
1528                (cursor_offset - range.start).min(range.end - cursor_offset)
1529            }
1530        });
1531
1532        let mut results = Vec::new();
1533        let mut diagnostic_groups_truncated = false;
1534        let mut diagnostics_byte_count = 0;
1535        for group in diagnostic_groups {
1536            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1537            diagnostics_byte_count += raw_value.get().len();
1538            if diagnostics_byte_count > max_diagnostics_bytes {
1539                diagnostic_groups_truncated = true;
1540                break;
1541            }
1542            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1543        }
1544
1545        (results, diagnostic_groups_truncated)
1546    }
1547
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds the cloud `PredictEditsRequest` for the cursor position, for
    /// consumption by the zeta CLI.
    ///
    /// Only `ContextMode::Syntax` is supported; the agentic mode panics with
    /// a TODO. Events, data-collection flags, and diagnostics are not yet
    /// populated (see the TODO below).
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, when the buffer is
        // backed by a worktree file with a parent.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            // Hold the syntax index lock (if any) while building the request.
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
1624
1625    pub fn wait_for_initial_indexing(
1626        &mut self,
1627        project: &Entity<Project>,
1628        cx: &mut App,
1629    ) -> Task<Result<()>> {
1630        let zeta_project = self.get_or_init_zeta_project(project, cx);
1631        zeta_project
1632            .syntax_index
1633            .read(cx)
1634            .wait_for_initial_file_indexing(cx)
1635    }
1636}
1637
1638pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1639    let choice = res.choices.pop()?;
1640    let output_text = match choice.message {
1641        open_ai::RequestMessage::Assistant {
1642            content: Some(open_ai::MessageContent::Plain(content)),
1643            ..
1644        } => content,
1645        open_ai::RequestMessage::Assistant {
1646            content: Some(open_ai::MessageContent::Multipart(mut content)),
1647            ..
1648        } => {
1649            if content.is_empty() {
1650                log::error!("No output from Baseten completion response");
1651                return None;
1652            }
1653
1654            match content.remove(0) {
1655                open_ai::MessagePart::Text { text } => text,
1656                open_ai::MessagePart::Image { .. } => {
1657                    log::error!("Expected text, got an image");
1658                    return None;
1659                }
1660            }
1661        }
1662        _ => {
1663            log::error!("Invalid response message: {:?}", choice.message);
1664            return None;
1665        }
1666    };
1667    Some(output_text)
1668}
1669
/// Error returned when the prediction server requires a newer Zed version;
/// `minimum_version` is the oldest version the server will accept.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: SemanticVersion,
}
1677
1678fn make_syntax_context_cloud_request(
1679    excerpt_path: Arc<Path>,
1680    context: EditPredictionContext,
1681    events: Vec<predict_edits_v3::Event>,
1682    can_collect_data: bool,
1683    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1684    diagnostic_groups_truncated: bool,
1685    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1686    debug_info: bool,
1687    worktrees: &Vec<worktree::Snapshot>,
1688    index_state: Option<&SyntaxIndexState>,
1689    prompt_max_bytes: Option<usize>,
1690    prompt_format: PromptFormat,
1691) -> predict_edits_v3::PredictEditsRequest {
1692    let mut signatures = Vec::new();
1693    let mut declaration_to_signature_index = HashMap::default();
1694    let mut referenced_declarations = Vec::new();
1695
1696    for snippet in context.declarations {
1697        let project_entry_id = snippet.declaration.project_entry_id();
1698        let Some(path) = worktrees.iter().find_map(|worktree| {
1699            worktree.entry_for_id(project_entry_id).map(|entry| {
1700                let mut full_path = RelPathBuf::new();
1701                full_path.push(worktree.root_name());
1702                full_path.push(&entry.path);
1703                full_path
1704            })
1705        }) else {
1706            continue;
1707        };
1708
1709        let parent_index = index_state.and_then(|index_state| {
1710            snippet.declaration.parent().and_then(|parent| {
1711                add_signature(
1712                    parent,
1713                    &mut declaration_to_signature_index,
1714                    &mut signatures,
1715                    index_state,
1716                )
1717            })
1718        });
1719
1720        let (text, text_is_truncated) = snippet.declaration.item_text();
1721        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1722            path: path.as_std_path().into(),
1723            text: text.into(),
1724            range: snippet.declaration.item_line_range(),
1725            text_is_truncated,
1726            signature_range: snippet.declaration.signature_range_in_item_text(),
1727            parent_index,
1728            signature_score: snippet.score(DeclarationStyle::Signature),
1729            declaration_score: snippet.score(DeclarationStyle::Declaration),
1730            score_components: snippet.components,
1731        });
1732    }
1733
1734    let excerpt_parent = index_state.and_then(|index_state| {
1735        context
1736            .excerpt
1737            .parent_declarations
1738            .last()
1739            .and_then(|(parent, _)| {
1740                add_signature(
1741                    *parent,
1742                    &mut declaration_to_signature_index,
1743                    &mut signatures,
1744                    index_state,
1745                )
1746            })
1747    });
1748
1749    predict_edits_v3::PredictEditsRequest {
1750        excerpt_path,
1751        excerpt: context.excerpt_text.body,
1752        excerpt_line_range: context.excerpt.line_range,
1753        excerpt_range: context.excerpt.range,
1754        cursor_point: predict_edits_v3::Point {
1755            line: predict_edits_v3::Line(context.cursor_point.row),
1756            column: context.cursor_point.column,
1757        },
1758        referenced_declarations,
1759        included_files: vec![],
1760        signatures,
1761        excerpt_parent,
1762        events,
1763        can_collect_data,
1764        diagnostic_groups,
1765        diagnostic_groups_truncated,
1766        git_info,
1767        debug_info,
1768        prompt_max_bytes,
1769        prompt_format,
1770    }
1771}
1772
1773fn add_signature(
1774    declaration_id: DeclarationId,
1775    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1776    signatures: &mut Vec<Signature>,
1777    index: &SyntaxIndexState,
1778) -> Option<usize> {
1779    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1780        return Some(*signature_index);
1781    }
1782    let Some(parent_declaration) = index.declaration(declaration_id) else {
1783        log::error!("bug: missing parent declaration");
1784        return None;
1785    };
1786    let parent_index = parent_declaration.parent().and_then(|parent| {
1787        add_signature(parent, declaration_to_signature_index, signatures, index)
1788    });
1789    let (text, text_is_truncated) = parent_declaration.signature_text();
1790    let signature_index = signatures.len();
1791    signatures.push(Signature {
1792        text: text.into(),
1793        text_is_truncated,
1794        parent_index,
1795        range: parent_declaration.signature_line_range(),
1796    });
1797    declaration_to_signature_index.insert(declaration_id, signature_index);
1798    Some(signature_index)
1799}
1800
/// Cache key for eval runs: the kind of cached artifact plus a 64-bit hash
/// of the input that produced it.
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
1803
/// Kind of artifact stored in the eval cache; rendered lowercase by its
/// `Display` impl.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    // Context-gathering stage.
    Context,
    // Search/retrieval stage.
    Search,
    // Model prediction stage.
    Prediction,
}
1811
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the lowercase name of this entry kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(name)
    }
}
1822
/// Read/write interface for caching intermediate eval results, keyed by
/// [`EvalCacheKey`].
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    // Returns the cached value for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    // Stores `value` under `key`. NOTE(review): `input` appears to be the raw
    // request text that produced `value` — confirm against implementations.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
1828
#[cfg(test)]
mod tests {
    use std::{path::Path, sync::Arc};

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::OffsetRangeExt as _;
    use open_ai::Usage;
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    // Exercises Zeta's current-prediction bookkeeping end to end: a prediction
    // for the cursor's buffer surfaces as `Local`, a prediction targeting a
    // different file surfaces as `Jump` from the original buffer and as
    // `Local` once that file's buffer is opened.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // Respond with a diff that edits the buffer under the cursor.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Reply with a tool call invoking the search tool against 2.txt.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        prediction_task.await.unwrap();
        // From buffer1's point of view, the prediction is a jump to 2.txt.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once 2.txt is open, the same prediction is local to its buffer.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    // A single prediction round-trip: the model's diff is converted into a
    // buffer edit anchored at the cursor position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The diff line "How" -> "How are you?" becomes a single insertion
        // of " are you?" at row 1, column 3.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    // Verifies that buffer edits made before a prediction request show up as
    // an event diff in the prompt sent to the model.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Edit the registered buffer so the change is recorded as an event.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The prompt should contain the recorded edit as a unified diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    // Skipped until we start including diagnostics in prompt
    // #[gpui::test]
    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
    //     let (zeta, mut req_rx) = init_test(cx);
    //     let fs = FakeFs::new(cx.executor());
    //     fs.insert_tree(
    //         "/root",
    //         json!({
    //             "foo.md": "Hello!\nBye"
    //         }),
    //     )
    //     .await;
    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
    //     let diagnostic = lsp::Diagnostic {
    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
    //         ..Default::default()
    //     };

    //     project.update(cx, |project, cx| {
    //         project.lsp_store().update(cx, |lsp_store, cx| {
    //             // Create some diagnostics
    //             lsp_store
    //                 .update_diagnostics(
    //                     LanguageServerId(0),
    //                     lsp::PublishDiagnosticsParams {
    //                         uri: path_to_buffer_uri.clone(),
    //                         diagnostics: vec![diagnostic],
    //                         version: None,
    //                     },
    //                     None,
    //                     language::DiagnosticSourceKind::Pushed,
    //                     &[],
    //                     cx,
    //                 )
    //                 .unwrap();
    //         });
    //     });

    //     let buffer = project
    //         .update(cx, |project, cx| {
    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
    //             project.open_buffer(path, cx)
    //         })
    //         .await
    //         .unwrap();

    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    //     let position = snapshot.anchor_before(language::Point::new(0, 0));

    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
    //         zeta.request_prediction(&project, &buffer, position, cx)
    //     });

    //     let (request, _respond_tx) = req_rx.next().await.unwrap();

    //     assert_eq!(request.diagnostic_groups.len(), 1);
    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
    //         .unwrap();
    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
    //     assert_eq!(
    //         value,
    //         json!({
    //             "entries": [{
    //                 "range": {
    //                     "start": 8,
    //                     "end": 10
    //                 },
    //                 "diagnostic": {
    //                     "source": null,
    //                     "code": null,
    //                     "code_description": null,
    //                     "severity": 1,
    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
    //                     "markdown": null,
    //                     "group_id": 0,
    //                     "is_primary": true,
    //                     "is_disk_based": false,
    //                     "is_unnecessary": false,
    //                     "source_kind": "Pushed",
    //                     "data": null,
    //                     "underline": true
    //                 }
    //             }],
    //             "primary_ix": 0
    //         })
    //     );
    // }

    /// Builds a minimal OpenAI-style response whose single assistant message
    /// contains `text` as plain content.
    fn model_response(text: &str) -> open_ai::Response {
        open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
                    tool_calls: vec![],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }

    /// Returns the prompt text from a request that must consist of exactly
    /// one plain user message; panics otherwise.
    fn prompt_from_request(request: &open_ai::Request) -> &str {
        assert_eq!(request.messages.len(), 1);
        let open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(content),
            ..
        } = &request.messages[0]
        else {
            panic!(
                "Request does not have single user message of type Plain. {:#?}",
                request
            );
        };
        content
    }

    /// Sets up a `Zeta` entity backed by a fake HTTP client. Returns the
    /// entity plus a receiver yielding each prediction request together with
    /// a oneshot sender the test uses to supply the model's response.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh endpoint: always hand out a dummy token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction endpoint: forward the request to the
                            // test and wait for it to reply via the oneshot.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}