// zeta2.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  10use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  11use collections::HashMap;
  12use edit_prediction_context::{
  13    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  14    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  15    SyntaxIndex, SyntaxIndexState,
  16};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::AsyncReadExt as _;
  19use futures::channel::{mpsc, oneshot};
  20use gpui::http_client::{AsyncBody, Method};
  21use gpui::{
  22    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  23    http_client, prelude::*,
  24};
  25use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  26use language::{BufferSnapshot, OffsetRangeExt};
  27use language_model::{LlmApiToken, RefreshLlmTokenListener};
  28use open_ai::FunctionDefinition;
  29use project::Project;
  30use release_channel::AppVersion;
  31use serde::de::DeserializeOwned;
  32use std::collections::{VecDeque, hash_map};
  33
  34use std::env;
  35use std::ops::Range;
  36use std::path::Path;
  37use std::str::FromStr as _;
  38use std::sync::{Arc, LazyLock};
  39use std::time::{Duration, Instant};
  40use thiserror::Error;
  41use util::rel_path::RelPathBuf;
  42use util::{LogErrorFuture, TryFutureExt};
  43use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  44
  45pub mod assemble_excerpts;
  46mod prediction;
  47mod provider;
  48pub mod retrieval_search;
  49pub mod udiff;
  50mod xml_edits;
  51
  52use crate::assemble_excerpts::assemble_excerpts;
  53use crate::prediction::EditPrediction;
  54pub use crate::prediction::EditPredictionId;
  55pub use provider::ZetaEditPredictionProvider;
  56
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Default sizing for the excerpt captured around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim to place roughly half of the excerpt before the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context mode: agentic (search-driven) retrieval.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Default options for the agentic context mode.
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Default options for the syntax-index-based context mode.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default top-level Zeta options.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
  91
// Set the `ZED_ZETA2_OLLAMA` env var (to any non-empty value) to route
// requests to a local Ollama server and use Ollama model ids by default.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
// Model used for context retrieval; override with `ZED_ZETA2_CONTEXT_MODEL`.
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
// Model used for edit predictions; override with `ZED_ZETA2_MODEL`.
// `zeta2-exp` is an alias for the fine-tuned Baseten deployment.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
// Prediction endpoint override: `ZED_PREDICT_EDITS_URL` wins; otherwise the
// local Ollama URL when `ZED_ZETA2_OLLAMA` is set; `None` means the default
// cloud endpoint is used.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
 119
/// Feature flag gating the Zeta2 edit-prediction system.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Staff do not get this flag automatically; it must be enabled explicitly.
    fn enabled_for_staff() -> bool {
        false
    }
}
 129
/// Global holder for the shared [`Zeta`] entity (see [`Zeta::global`]).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 134
/// Coordinates Zeta2 edit predictions: per-project state, buffer change
/// history, options, and the LLM token used to reach the prediction backend.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Cached LLM API token, refreshed by `_llm_token_subscription`.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Keyed by the `Project` entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // Set when the server reports this client is too old (handled elsewhere).
    update_required: bool,
    // When present, debug events are streamed to this channel (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
}
 147
/// Tunable knobs for prediction requests and context gathering.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    /// Edits closer together than this are coalesced into one history event.
    pub buffer_change_grouping_interval: Duration,
}
 157
/// How context is gathered for a prediction request.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Model-driven search retrieval.
    Agentic(AgenticContextOptions),
    /// Syntax-index-based retrieval (uses a [`SyntaxIndex`]).
    Syntax(EditPredictionContextOptions),
}

/// Options for [`ContextMode::Agentic`].
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
 168
 169impl ContextMode {
 170    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 171        match self {
 172            ContextMode::Agentic(options) => &options.excerpt,
 173            ContextMode::Syntax(options) => &options.excerpt,
 174        }
 175    }
 176}
 177
/// Events streamed to debug subscribers (see [`Zeta::debug_info`]).
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}
 186
/// Debug payload emitted when context retrieval begins.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// The prompt sent to the search/retrieval model.
    pub search_prompt: String,
}
 193
/// Debug payload for context-retrieval progress events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}
 199
/// Debug payload emitted when an edit-prediction request is issued.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent on context retrieval before the request was made.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// Locally assembled prompt, or an error message if assembly failed.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error string) and elapsed time.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}
 209
/// Debug payload emitted when search queries have been generated.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}
 216
/// Alias for the wire-format debug info attached to prediction requests.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 218
/// Per-project prediction state, stored in [`Zeta::projects`].
struct ZetaProject {
    /// Only populated in [`ContextMode::Syntax`] (see `get_or_init_zeta_project`).
    syntax_index: Option<Entity<SyntaxIndex>>,
    /// Recent buffer-change history, capped at `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// Retrieved context ranges per buffer; `None` until retrieval has run.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}
 229
/// The prediction currently offered to the user, plus which buffer asked for it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the request originated from (may differ from
    /// the buffer the prediction targets).
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
 235
impl CurrentEditPrediction {
    /// Decides whether `self` (a newly arrived prediction) should replace
    /// `old_prediction` as the current one.
    ///
    /// Returns `false` only when the new prediction no longer applies to its
    /// buffer, or when it merely truncates what the old prediction already
    /// showed (to avoid UI thrash).
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // If the new prediction can't be re-applied to the buffer's current
        // contents, keep the old one.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        // Predictions for a different buffer always win.
        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // If the old prediction no longer applies, replace it.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // Single-edit case: only replace when the new edit extends the old
            // one in place (same range, old text is a prefix of the new text).
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
 272
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer directly.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer; this buffer only requested it.
    Jump { prediction: &'a EditPrediction },
}
 279
/// Bookkeeping for a buffer whose edits are being tracked.
struct RegisteredBuffer {
    /// Snapshot as of the last recorded change event.
    snapshot: BufferSnapshot,
    // Edit subscription + release observer (see `register_buffer_impl`).
    _subscriptions: [gpui::Subscription; 2],
}
 284
/// A recorded edit-history event.
#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
 293
impl Event {
    /// Converts this history event into the wire format sent with prediction
    /// requests. Returns `None` when the event carries no information.
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                // Only report the old path when it differs, i.e. on a rename.
                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2] move to bg?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                // NOTE(review): since `old_path` is `None` unless the path
                // changed, `path == old_path` only holds when both are `None`.
                // An empty diff on an unrenamed file therefore still emits an
                // event — confirm whether `old_path.is_none() && diff.is_empty()`
                // was intended here.
                if path == old_path && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        //todo: Actually detect if this edit was predicted or not
                        predicted: false,
                    })
                }
            }
        }
    }
}
 331
 332impl Zeta {
 333    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 334        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 335    }
 336
 337    pub fn global(
 338        client: &Arc<Client>,
 339        user_store: &Entity<UserStore>,
 340        cx: &mut App,
 341    ) -> Entity<Self> {
 342        cx.try_global::<ZetaGlobal>()
 343            .map(|global| global.0.clone())
 344            .unwrap_or_else(|| {
 345                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 346                cx.set_global(ZetaGlobal(zeta.clone()));
 347                zeta
 348            })
 349    }
 350
    /// Creates a new `Zeta` with default options and no tracked projects.
    ///
    /// Subscribes to the global [`RefreshLlmTokenListener`] so the cached LLM
    /// token is re-fetched from the server whenever the listener fires.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    // Refresh in the background; failures are logged, not fatal.
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
        }
    }
 378
    /// Installs a cache used by evals to avoid repeated backend calls.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 383
 384    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 385        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 386        self.debug_tx = Some(debug_watch_tx);
 387        debug_watch_rx
 388    }
 389
    /// Returns the current prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prediction options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 397
 398    pub fn clear_history(&mut self) {
 399        for zeta_project in self.projects.values_mut() {
 400            zeta_project.events.clear();
 401        }
 402    }
 403
 404    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
 405        self.projects
 406            .get(&project.entity_id())
 407            .map(|project| project.events.iter())
 408            .into_iter()
 409            .flatten()
 410    }
 411
 412    pub fn context_for_project(
 413        &self,
 414        project: &Entity<Project>,
 415    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 416        self.projects
 417            .get(&project.entity_id())
 418            .and_then(|project| {
 419                Some(
 420                    project
 421                        .context
 422                        .as_ref()?
 423                        .iter()
 424                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 425                )
 426            })
 427            .into_iter()
 428            .flatten()
 429    }
 430
    /// Returns the user's edit-prediction usage stats, if known.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }

    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }
 438
    /// Starts tracking `buffer` within `project`, initializing project state
    /// if needed, so its edits feed the prediction history.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 448
 449    fn get_or_init_zeta_project(
 450        &mut self,
 451        project: &Entity<Project>,
 452        cx: &mut App,
 453    ) -> &mut ZetaProject {
 454        self.projects
 455            .entry(project.entity_id())
 456            .or_insert_with(|| ZetaProject {
 457                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
 458                    Some(cx.new(|cx| {
 459                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 460                    }))
 461                } else {
 462                    None
 463                },
 464                events: VecDeque::new(),
 465                registered_buffers: HashMap::default(),
 466                current_prediction: None,
 467                context: None,
 468                refresh_context_task: None,
 469                refresh_context_debounce_task: None,
 470                refresh_context_timestamp: None,
 471            })
 472    }
 473
    /// Ensures `buffer` has an entry in `zeta_project.registered_buffers`,
    /// returning it.
    ///
    /// On first registration, stores the buffer's current snapshot and wires
    /// two subscriptions: one feeds edits into `report_changes_for_buffer`,
    /// the other removes the entry when the buffer entity is released.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Hold a weak handle so the subscription doesn't
                            // keep the project alive.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop our bookkeeping when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 511
    /// Records a [`Event::BufferChange`] for `buffer` if its version changed
    /// since the last recorded snapshot, and returns the latest snapshot.
    ///
    /// Also registers the buffer/project on first sight.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Only record an event when the buffer actually changed versions.
        if new_snapshot.version != registered_buffer.snapshot.version {
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                buffer_change_grouping_interval,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }
 539
    /// Appends `event` to the project's edit history.
    ///
    /// Consecutive edits to the same buffer arriving within
    /// `buffer_change_grouping_interval` are coalesced into the previous event
    /// (its new snapshot and timestamp are updated in place). The history is
    /// capped at `MAX_EVENT_COUNT`; when full, the oldest half is dropped.
    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            // The new event must start exactly where the last one left off
            // (same buffer, same version) to be merged.
            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }
 578
 579    fn current_prediction_for_buffer(
 580        &self,
 581        buffer: &Entity<Buffer>,
 582        project: &Entity<Project>,
 583        cx: &App,
 584    ) -> Option<BufferEditPrediction<'_>> {
 585        let project_state = self.projects.get(&project.entity_id())?;
 586
 587        let CurrentEditPrediction {
 588            requested_by_buffer_id,
 589            prediction,
 590        } = project_state.current_prediction.as_ref()?;
 591
 592        if prediction.targets_buffer(buffer.read(cx)) {
 593            Some(BufferEditPrediction::Local { prediction })
 594        } else if *requested_by_buffer_id == buffer.entity_id() {
 595            Some(BufferEditPrediction::Jump { prediction })
 596        } else {
 597            None
 598        }
 599    }
 600
    /// Marks the project's current prediction as accepted: removes it from
    /// project state and reports the acceptance to the backend.
    ///
    /// The endpoint defaults to the Zed LLM `/predict_edits/accept` URL and
    /// can be overridden with the `ZED_ACCEPT_PREDICTION_URL` env var.
    /// Errors are logged, not surfaced to the caller.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Accepting consumes the prediction; nothing to do if there is none.
        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            // POST the accepted request id off the main thread.
            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 645
 646    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 647        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 648            project_state.current_prediction.take();
 649        };
 650    }
 651
 652    pub fn refresh_prediction(
 653        &mut self,
 654        project: &Entity<Project>,
 655        buffer: &Entity<Buffer>,
 656        position: language::Anchor,
 657        cx: &mut Context<Self>,
 658    ) -> Task<Result<()>> {
 659        let request_task = self.request_prediction(project, buffer, position, cx);
 660        let buffer = buffer.clone();
 661        let project = project.clone();
 662
 663        cx.spawn(async move |this, cx| {
 664            if let Some(prediction) = request_task.await? {
 665                this.update(cx, |this, cx| {
 666                    let project_state = this
 667                        .projects
 668                        .get_mut(&project.entity_id())
 669                        .context("Project not found")?;
 670
 671                    let new_prediction = CurrentEditPrediction {
 672                        requested_by_buffer_id: buffer.entity_id(),
 673                        prediction: prediction,
 674                    };
 675
 676                    if project_state
 677                        .current_prediction
 678                        .as_ref()
 679                        .is_none_or(|old_prediction| {
 680                            new_prediction.should_replace_prediction(&old_prediction, cx)
 681                        })
 682                    {
 683                        project_state.current_prediction = Some(new_prediction);
 684                    }
 685                    anyhow::Ok(())
 686                })??;
 687            }
 688            Ok(())
 689        })
 690    }
 691
 692    pub fn request_prediction(
 693        &mut self,
 694        project: &Entity<Project>,
 695        active_buffer: &Entity<Buffer>,
 696        position: language::Anchor,
 697        cx: &mut Context<Self>,
 698    ) -> Task<Result<Option<EditPrediction>>> {
 699        let project_state = self.projects.get(&project.entity_id());
 700
 701        let index_state = project_state.and_then(|state| {
 702            state
 703                .syntax_index
 704                .as_ref()
 705                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
 706        });
 707        let options = self.options.clone();
 708        let active_snapshot = active_buffer.read(cx).snapshot();
 709        let Some(excerpt_path) = active_snapshot
 710            .file()
 711            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
 712        else {
 713            return Task::ready(Err(anyhow!("No file path for excerpt")));
 714        };
 715        let client = self.client.clone();
 716        let llm_token = self.llm_token.clone();
 717        let app_version = AppVersion::global(cx);
 718        let worktree_snapshots = project
 719            .read(cx)
 720            .worktrees(cx)
 721            .map(|worktree| worktree.read(cx).snapshot())
 722            .collect::<Vec<_>>();
 723        let debug_tx = self.debug_tx.clone();
 724
 725        let events = project_state
 726            .map(|state| {
 727                state
 728                    .events
 729                    .iter()
 730                    .filter_map(|event| event.to_request_event(cx))
 731                    .collect::<Vec<_>>()
 732            })
 733            .unwrap_or_default();
 734
 735        let diagnostics = active_snapshot.diagnostic_sets().clone();
 736
 737        let parent_abs_path =
 738            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
 739                let mut path = f.worktree.read(cx).absolutize(&f.path);
 740                if path.pop() { Some(path) } else { None }
 741            });
 742
 743        // TODO data collection
 744        let can_collect_data = cx.is_staff();
 745
 746        let empty_context_files = HashMap::default();
 747        let context_files = project_state
 748            .and_then(|project_state| project_state.context.as_ref())
 749            .unwrap_or(&empty_context_files);
 750
 751        #[cfg(feature = "eval-support")]
 752        let parsed_fut = futures::future::join_all(
 753            context_files
 754                .keys()
 755                .map(|buffer| buffer.read(cx).parsing_idle()),
 756        );
 757
 758        let mut included_files = context_files
 759            .iter()
 760            .filter_map(|(buffer_entity, ranges)| {
 761                let buffer = buffer_entity.read(cx);
 762                Some((
 763                    buffer_entity.clone(),
 764                    buffer.snapshot(),
 765                    buffer.file()?.full_path(cx).into(),
 766                    ranges.clone(),
 767                ))
 768            })
 769            .collect::<Vec<_>>();
 770
 771        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
 772            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
 773        });
 774
 775        #[cfg(feature = "eval-support")]
 776        let eval_cache = self.eval_cache.clone();
 777
 778        let request_task = cx.background_spawn({
 779            let active_buffer = active_buffer.clone();
 780            async move {
 781                #[cfg(feature = "eval-support")]
 782                parsed_fut.await;
 783
 784                let index_state = if let Some(index_state) = index_state {
 785                    Some(index_state.lock_owned().await)
 786                } else {
 787                    None
 788                };
 789
 790                let cursor_offset = position.to_offset(&active_snapshot);
 791                let cursor_point = cursor_offset.to_point(&active_snapshot);
 792
 793                let before_retrieval = chrono::Utc::now();
 794
 795                let (diagnostic_groups, diagnostic_groups_truncated) =
 796                    Self::gather_nearby_diagnostics(
 797                        cursor_offset,
 798                        &diagnostics,
 799                        &active_snapshot,
 800                        options.max_diagnostic_bytes,
 801                    );
 802
 803                let cloud_request = match options.context {
 804                    ContextMode::Agentic(context_options) => {
 805                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
 806                            cursor_point,
 807                            &active_snapshot,
 808                            &context_options.excerpt,
 809                            index_state.as_deref(),
 810                        ) else {
 811                            return Ok((None, None));
 812                        };
 813
 814                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
 815                            ..active_snapshot.anchor_before(excerpt.range.end);
 816
 817                        if let Some(buffer_ix) =
 818                            included_files.iter().position(|(_, snapshot, _, _)| {
 819                                snapshot.remote_id() == active_snapshot.remote_id()
 820                            })
 821                        {
 822                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
 823                            ranges.push(excerpt_anchor_range);
 824                            retrieval_search::merge_anchor_ranges(ranges, buffer);
 825                            let last_ix = included_files.len() - 1;
 826                            included_files.swap(buffer_ix, last_ix);
 827                        } else {
 828                            included_files.push((
 829                                active_buffer.clone(),
 830                                active_snapshot.clone(),
 831                                excerpt_path.clone(),
 832                                vec![excerpt_anchor_range],
 833                            ));
 834                        }
 835
 836                        let included_files = included_files
 837                            .iter()
 838                            .map(|(_, snapshot, path, ranges)| {
 839                                let ranges = ranges
 840                                    .iter()
 841                                    .map(|range| {
 842                                        let point_range = range.to_point(&snapshot);
 843                                        Line(point_range.start.row)..Line(point_range.end.row)
 844                                    })
 845                                    .collect::<Vec<_>>();
 846                                let excerpts = assemble_excerpts(&snapshot, ranges);
 847                                predict_edits_v3::IncludedFile {
 848                                    path: path.clone(),
 849                                    max_row: Line(snapshot.max_point().row),
 850                                    excerpts,
 851                                }
 852                            })
 853                            .collect::<Vec<_>>();
 854
 855                        predict_edits_v3::PredictEditsRequest {
 856                            excerpt_path,
 857                            excerpt: String::new(),
 858                            excerpt_line_range: Line(0)..Line(0),
 859                            excerpt_range: 0..0,
 860                            cursor_point: predict_edits_v3::Point {
 861                                line: predict_edits_v3::Line(cursor_point.row),
 862                                column: cursor_point.column,
 863                            },
 864                            included_files,
 865                            referenced_declarations: vec![],
 866                            events,
 867                            can_collect_data,
 868                            diagnostic_groups,
 869                            diagnostic_groups_truncated,
 870                            debug_info: debug_tx.is_some(),
 871                            prompt_max_bytes: Some(options.max_prompt_bytes),
 872                            prompt_format: options.prompt_format,
 873                            // TODO [zeta2]
 874                            signatures: vec![],
 875                            excerpt_parent: None,
 876                            git_info: None,
 877                        }
 878                    }
 879                    ContextMode::Syntax(context_options) => {
 880                        let Some(context) = EditPredictionContext::gather_context(
 881                            cursor_point,
 882                            &active_snapshot,
 883                            parent_abs_path.as_deref(),
 884                            &context_options,
 885                            index_state.as_deref(),
 886                        ) else {
 887                            return Ok((None, None));
 888                        };
 889
 890                        make_syntax_context_cloud_request(
 891                            excerpt_path,
 892                            context,
 893                            events,
 894                            can_collect_data,
 895                            diagnostic_groups,
 896                            diagnostic_groups_truncated,
 897                            None,
 898                            debug_tx.is_some(),
 899                            &worktree_snapshots,
 900                            index_state.as_deref(),
 901                            Some(options.max_prompt_bytes),
 902                            options.prompt_format,
 903                        )
 904                    }
 905                };
 906
 907                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
 908
 909                let retrieval_time = chrono::Utc::now() - before_retrieval;
 910
 911                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
 912                    let (response_tx, response_rx) = oneshot::channel();
 913
 914                    debug_tx
 915                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
 916                            ZetaEditPredictionDebugInfo {
 917                                request: cloud_request.clone(),
 918                                retrieval_time,
 919                                buffer: active_buffer.downgrade(),
 920                                local_prompt: match prompt_result.as_ref() {
 921                                    Ok((prompt, _)) => Ok(prompt.clone()),
 922                                    Err(err) => Err(err.to_string()),
 923                                },
 924                                position,
 925                                response_rx,
 926                            },
 927                        ))
 928                        .ok();
 929                    Some(response_tx)
 930                } else {
 931                    None
 932                };
 933
 934                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 935                    if let Some(debug_response_tx) = debug_response_tx {
 936                        debug_response_tx
 937                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
 938                            .ok();
 939                    }
 940                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 941                }
 942
 943                let (prompt, _) = prompt_result?;
 944                let request = open_ai::Request {
 945                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
 946                    messages: vec![open_ai::RequestMessage::User {
 947                        content: open_ai::MessageContent::Plain(prompt),
 948                    }],
 949                    stream: false,
 950                    max_completion_tokens: None,
 951                    stop: Default::default(),
 952                    temperature: 0.7,
 953                    tool_choice: None,
 954                    parallel_tool_calls: None,
 955                    tools: vec![],
 956                    prompt_cache_key: None,
 957                    reasoning_effort: None,
 958                };
 959
 960                log::trace!("Sending edit prediction request");
 961
 962                let before_request = chrono::Utc::now();
 963                let response = Self::send_raw_llm_request(
 964                    request,
 965                    client,
 966                    llm_token,
 967                    app_version,
 968                    #[cfg(feature = "eval-support")]
 969                    eval_cache,
 970                    #[cfg(feature = "eval-support")]
 971                    EvalCacheEntryKind::Prediction,
 972                )
 973                .await;
 974                let request_time = chrono::Utc::now() - before_request;
 975
 976                log::trace!("Got edit prediction response");
 977
 978                if let Some(debug_response_tx) = debug_response_tx {
 979                    debug_response_tx
 980                        .send((
 981                            response
 982                                .as_ref()
 983                                .map_err(|err| err.to_string())
 984                                .map(|response| response.0.clone()),
 985                            request_time,
 986                        ))
 987                        .ok();
 988                }
 989
 990                let (res, usage) = response?;
 991                let request_id = EditPredictionId(res.id.clone().into());
 992                let Some(mut output_text) = text_from_response(res) else {
 993                    return Ok((None, usage));
 994                };
 995
 996                if output_text.contains(CURSOR_MARKER) {
 997                    log::trace!("Stripping out {CURSOR_MARKER} from response");
 998                    output_text = output_text.replace(CURSOR_MARKER, "");
 999                }
1000
1001                let get_buffer_from_context = |path: &Path| {
1002                    included_files
1003                        .iter()
1004                        .find_map(|(_, buffer, probe_path, ranges)| {
1005                            if probe_path.as_ref() == path {
1006                                Some((buffer, ranges.as_slice()))
1007                            } else {
1008                                None
1009                            }
1010                        })
1011                };
1012
1013                let (edited_buffer_snapshot, edits) = match options.prompt_format {
1014                    PromptFormat::NumLinesUniDiff => {
1015                        // TODO: Implement parsing of multi-file diffs
1016                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1017                    }
1018                    PromptFormat::Minimal => {
1019                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
1020                            let edits = vec![];
1021                            (&active_snapshot, edits)
1022                        } else {
1023                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1024                        }
1025                    }
1026                    PromptFormat::OldTextNewText => {
1027                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
1028                            .await?
1029                    }
1030                    _ => {
1031                        bail!("unsupported prompt format {}", options.prompt_format)
1032                    }
1033                };
1034
1035                let edited_buffer = included_files
1036                    .iter()
1037                    .find_map(|(buffer, snapshot, _, _)| {
1038                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1039                            Some(buffer.clone())
1040                        } else {
1041                            None
1042                        }
1043                    })
1044                    .context("Failed to find buffer in included_buffers")?;
1045
1046                anyhow::Ok((
1047                    Some((
1048                        request_id,
1049                        edited_buffer,
1050                        edited_buffer_snapshot.clone(),
1051                        edits,
1052                    )),
1053                    usage,
1054                ))
1055            }
1056        });
1057
1058        cx.spawn({
1059            async move |this, cx| {
1060                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
1061                    Self::handle_api_response(&this, request_task.await, cx)?
1062                else {
1063                    return Ok(None);
1064                };
1065
1066                // TODO telemetry: duration, etc
1067                Ok(
1068                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
1069                        .await,
1070                )
1071            }
1072        })
1073    }
1074
1075    async fn send_raw_llm_request(
1076        request: open_ai::Request,
1077        client: Arc<Client>,
1078        llm_token: LlmApiToken,
1079        app_version: SemanticVersion,
1080        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1081        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
1082    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1083        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1084            http_client::Url::parse(&predict_edits_url)?
1085        } else {
1086            client
1087                .http_client()
1088                .build_zed_llm_url("/predict_edits/raw", &[])?
1089        };
1090
1091        #[cfg(feature = "eval-support")]
1092        let cache_key = if let Some(cache) = eval_cache {
1093            use collections::FxHasher;
1094            use std::hash::{Hash, Hasher};
1095
1096            let mut hasher = FxHasher::default();
1097            url.hash(&mut hasher);
1098            let request_str = serde_json::to_string_pretty(&request)?;
1099            request_str.hash(&mut hasher);
1100            let hash = hasher.finish();
1101
1102            let key = (eval_cache_kind, hash);
1103            if let Some(response_str) = cache.read(key) {
1104                return Ok((serde_json::from_str(&response_str)?, None));
1105            }
1106
1107            Some((cache, request_str, key))
1108        } else {
1109            None
1110        };
1111
1112        let (response, usage) = Self::send_api_request(
1113            |builder| {
1114                let req = builder
1115                    .uri(url.as_ref())
1116                    .body(serde_json::to_string(&request)?.into());
1117                Ok(req?)
1118            },
1119            client,
1120            llm_token,
1121            app_version,
1122        )
1123        .await?;
1124
1125        #[cfg(feature = "eval-support")]
1126        if let Some((cache, request, key)) = cache_key {
1127            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1128        }
1129
1130        Ok((response, usage))
1131    }
1132
1133    fn handle_api_response<T>(
1134        this: &WeakEntity<Self>,
1135        response: Result<(T, Option<EditPredictionUsage>)>,
1136        cx: &mut gpui::AsyncApp,
1137    ) -> Result<T> {
1138        match response {
1139            Ok((data, usage)) => {
1140                if let Some(usage) = usage {
1141                    this.update(cx, |this, cx| {
1142                        this.user_store.update(cx, |user_store, cx| {
1143                            user_store.update_edit_prediction_usage(usage, cx);
1144                        });
1145                    })
1146                    .ok();
1147                }
1148                Ok(data)
1149            }
1150            Err(err) => {
1151                if err.is::<ZedUpdateRequiredError>() {
1152                    cx.update(|cx| {
1153                        this.update(cx, |this, _cx| {
1154                            this.update_required = true;
1155                        })
1156                        .ok();
1157
1158                        let error_message: SharedString = err.to_string().into();
1159                        show_app_notification(
1160                            NotificationId::unique::<ZedUpdateRequiredError>(),
1161                            cx,
1162                            move |cx| {
1163                                cx.new(|cx| {
1164                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1165                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1166                                })
1167                            },
1168                        );
1169                    })
1170                    .ok();
1171                }
1172                Err(err)
1173            }
1174        }
1175    }
1176
    /// Sends an authenticated POST request built by `build` to an LLM endpoint
    /// and deserializes the JSON response body into `Res`.
    ///
    /// Behavior:
    /// - Attaches the bearer LLM token and the current app version headers.
    /// - If the response carries a minimum-required-version header and this
    ///   app version is older, fails with `ZedUpdateRequiredError` — this is
    ///   checked regardless of the response status.
    /// - If the server flags an expired LLM token, refreshes the token and
    ///   retries exactly once.
    /// - On success, also extracts `EditPredictionUsage` from the response
    ///   headers when present.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // Loop exists only to support the single token-refresh retry below.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // The server may demand a minimum client version on any response;
            // enforce it before interpreting the body.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh it and go around the loop once more.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                // Any other failure: surface the status and body to the caller.
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1240
    // If no edits happen within this window, the user is considered idle.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    // How long to wait after the latest edit before refreshing context.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);

    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context refresh only applies to the agentic context-gathering mode.
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Treat this edit as "after idle" when the previous refresh happened
        // longer ago than the idle duration (or never happened at all).
        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Overwriting the previous debounce task drops (cancels) it, so only
        // the most recent edit's task survives to schedule a refresh.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    // First edit after idle: refresh right away.
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    // Mid-typing: wait for a pause before refreshing.
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Hold onto the task so it runs to completion even if this
                    // debounce task is replaced.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
1291
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    //
    // Pipeline: select an excerpt around the cursor → build a context-retrieval
    // prompt → ask the planning model for search queries (via a tool call) →
    // run those searches over the project → store the resulting excerpts on
    // the project state. Debug events are emitted at each stage when a debug
    // channel is attached.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        // Only the agentic context mode performs LLM-planned retrieval.
        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        // Select the excerpt around the cursor that seeds the search plan;
        // bail quietly if no excerpt can be selected.
        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        // Untitled buffers get a placeholder path for the prompt.
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the planning prompt from the excerpt plus recent edit events.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Schema/description of the search tool, computed once per process.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            // The schema is derived from `SearchToolInput`, which is expected
            // to carry a doc-comment description.
            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        // Non-streaming planning request exposing exactly one tool: the search tool.
        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            // The plan is expected to arrive as tool calls on the assistant message.
            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect queries from every call to the search tool, skipping
            // (but logging) calls to unrecognized tools.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            // Execute the planned searches against the project.
            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Store the retrieved excerpts (or propagate the search error) and
            // clear the in-flight task handle.
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
1507
1508    pub fn set_context(
1509        &mut self,
1510        project: Entity<Project>,
1511        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
1512    ) {
1513        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
1514            zeta_project.context = Some(context);
1515        }
1516    }
1517
1518    fn gather_nearby_diagnostics(
1519        cursor_offset: usize,
1520        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1521        snapshot: &BufferSnapshot,
1522        max_diagnostics_bytes: usize,
1523    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1524        // TODO: Could make this more efficient
1525        let mut diagnostic_groups = Vec::new();
1526        for (language_server_id, diagnostics) in diagnostic_sets {
1527            let mut groups = Vec::new();
1528            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1529            diagnostic_groups.extend(
1530                groups
1531                    .into_iter()
1532                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1533            );
1534        }
1535
1536        // sort by proximity to cursor
1537        diagnostic_groups.sort_by_key(|group| {
1538            let range = &group.entries[group.primary_ix].range;
1539            if range.start >= cursor_offset {
1540                range.start - cursor_offset
1541            } else if cursor_offset >= range.end {
1542                cursor_offset - range.end
1543            } else {
1544                (cursor_offset - range.start).min(range.end - cursor_offset)
1545            }
1546        });
1547
1548        let mut results = Vec::new();
1549        let mut diagnostic_groups_truncated = false;
1550        let mut diagnostics_byte_count = 0;
1551        for group in diagnostic_groups {
1552            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1553            diagnostics_byte_count += raw_value.get().len();
1554            if diagnostics_byte_count > max_diagnostics_bytes {
1555                diagnostic_groups_truncated = true;
1556                break;
1557            }
1558            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1559        }
1560
1561        (results, diagnostic_groups_truncated)
1562    }
1563
1564    // TODO: Dedupe with similar code in request_prediction?
1565    pub fn cloud_request_for_zeta_cli(
1566        &mut self,
1567        project: &Entity<Project>,
1568        buffer: &Entity<Buffer>,
1569        position: language::Anchor,
1570        cx: &mut Context<Self>,
1571    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1572        let project_state = self.projects.get(&project.entity_id());
1573
1574        let index_state = project_state.and_then(|state| {
1575            state
1576                .syntax_index
1577                .as_ref()
1578                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
1579        });
1580        let options = self.options.clone();
1581        let snapshot = buffer.read(cx).snapshot();
1582        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1583            return Task::ready(Err(anyhow!("No file path for excerpt")));
1584        };
1585        let worktree_snapshots = project
1586            .read(cx)
1587            .worktrees(cx)
1588            .map(|worktree| worktree.read(cx).snapshot())
1589            .collect::<Vec<_>>();
1590
1591        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1592            let mut path = f.worktree.read(cx).absolutize(&f.path);
1593            if path.pop() { Some(path) } else { None }
1594        });
1595
1596        cx.background_spawn(async move {
1597            let index_state = if let Some(index_state) = index_state {
1598                Some(index_state.lock_owned().await)
1599            } else {
1600                None
1601            };
1602
1603            let cursor_point = position.to_point(&snapshot);
1604
1605            let debug_info = true;
1606            EditPredictionContext::gather_context(
1607                cursor_point,
1608                &snapshot,
1609                parent_abs_path.as_deref(),
1610                match &options.context {
1611                    ContextMode::Agentic(_) => {
1612                        // TODO
1613                        panic!("Llm mode not supported in zeta cli yet");
1614                    }
1615                    ContextMode::Syntax(edit_prediction_context_options) => {
1616                        edit_prediction_context_options
1617                    }
1618                },
1619                index_state.as_deref(),
1620            )
1621            .context("Failed to select excerpt")
1622            .map(|context| {
1623                make_syntax_context_cloud_request(
1624                    excerpt_path.into(),
1625                    context,
1626                    // TODO pass everything
1627                    Vec::new(),
1628                    false,
1629                    Vec::new(),
1630                    false,
1631                    None,
1632                    debug_info,
1633                    &worktree_snapshots,
1634                    index_state.as_deref(),
1635                    Some(options.max_prompt_bytes),
1636                    options.prompt_format,
1637                )
1638            })
1639        })
1640    }
1641
1642    pub fn wait_for_initial_indexing(
1643        &mut self,
1644        project: &Entity<Project>,
1645        cx: &mut App,
1646    ) -> Task<Result<()>> {
1647        let zeta_project = self.get_or_init_zeta_project(project, cx);
1648        if let Some(syntax_index) = &zeta_project.syntax_index {
1649            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
1650        } else {
1651            Task::ready(Ok(()))
1652        }
1653    }
1654}
1655
1656pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1657    let choice = res.choices.pop()?;
1658    let output_text = match choice.message {
1659        open_ai::RequestMessage::Assistant {
1660            content: Some(open_ai::MessageContent::Plain(content)),
1661            ..
1662        } => content,
1663        open_ai::RequestMessage::Assistant {
1664            content: Some(open_ai::MessageContent::Multipart(mut content)),
1665            ..
1666        } => {
1667            if content.is_empty() {
1668                log::error!("No output from Baseten completion response");
1669                return None;
1670            }
1671
1672            match content.remove(0) {
1673                open_ai::MessagePart::Text { text } => text,
1674                open_ai::MessagePart::Image { .. } => {
1675                    log::error!("Expected text, got an image");
1676                    return None;
1677                }
1678            }
1679        }
1680        _ => {
1681            log::error!("Invalid response message: {:?}", choice.message);
1682            return None;
1683        }
1684    };
1685    Some(output_text)
1686}
1687
/// Error raised when the user's Zed build is older than the minimum version
/// required for edit predictions (see `MINIMUM_REQUIRED_VERSION_HEADER_NAME`);
/// the `thiserror` message is the user-facing text.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: SemanticVersion,
}
1695
1696fn make_syntax_context_cloud_request(
1697    excerpt_path: Arc<Path>,
1698    context: EditPredictionContext,
1699    events: Vec<predict_edits_v3::Event>,
1700    can_collect_data: bool,
1701    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1702    diagnostic_groups_truncated: bool,
1703    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1704    debug_info: bool,
1705    worktrees: &Vec<worktree::Snapshot>,
1706    index_state: Option<&SyntaxIndexState>,
1707    prompt_max_bytes: Option<usize>,
1708    prompt_format: PromptFormat,
1709) -> predict_edits_v3::PredictEditsRequest {
1710    let mut signatures = Vec::new();
1711    let mut declaration_to_signature_index = HashMap::default();
1712    let mut referenced_declarations = Vec::new();
1713
1714    for snippet in context.declarations {
1715        let project_entry_id = snippet.declaration.project_entry_id();
1716        let Some(path) = worktrees.iter().find_map(|worktree| {
1717            worktree.entry_for_id(project_entry_id).map(|entry| {
1718                let mut full_path = RelPathBuf::new();
1719                full_path.push(worktree.root_name());
1720                full_path.push(&entry.path);
1721                full_path
1722            })
1723        }) else {
1724            continue;
1725        };
1726
1727        let parent_index = index_state.and_then(|index_state| {
1728            snippet.declaration.parent().and_then(|parent| {
1729                add_signature(
1730                    parent,
1731                    &mut declaration_to_signature_index,
1732                    &mut signatures,
1733                    index_state,
1734                )
1735            })
1736        });
1737
1738        let (text, text_is_truncated) = snippet.declaration.item_text();
1739        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1740            path: path.as_std_path().into(),
1741            text: text.into(),
1742            range: snippet.declaration.item_line_range(),
1743            text_is_truncated,
1744            signature_range: snippet.declaration.signature_range_in_item_text(),
1745            parent_index,
1746            signature_score: snippet.score(DeclarationStyle::Signature),
1747            declaration_score: snippet.score(DeclarationStyle::Declaration),
1748            score_components: snippet.components,
1749        });
1750    }
1751
1752    let excerpt_parent = index_state.and_then(|index_state| {
1753        context
1754            .excerpt
1755            .parent_declarations
1756            .last()
1757            .and_then(|(parent, _)| {
1758                add_signature(
1759                    *parent,
1760                    &mut declaration_to_signature_index,
1761                    &mut signatures,
1762                    index_state,
1763                )
1764            })
1765    });
1766
1767    predict_edits_v3::PredictEditsRequest {
1768        excerpt_path,
1769        excerpt: context.excerpt_text.body,
1770        excerpt_line_range: context.excerpt.line_range,
1771        excerpt_range: context.excerpt.range,
1772        cursor_point: predict_edits_v3::Point {
1773            line: predict_edits_v3::Line(context.cursor_point.row),
1774            column: context.cursor_point.column,
1775        },
1776        referenced_declarations,
1777        included_files: vec![],
1778        signatures,
1779        excerpt_parent,
1780        events,
1781        can_collect_data,
1782        diagnostic_groups,
1783        diagnostic_groups_truncated,
1784        git_info,
1785        debug_info,
1786        prompt_max_bytes,
1787        prompt_format,
1788    }
1789}
1790
1791fn add_signature(
1792    declaration_id: DeclarationId,
1793    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1794    signatures: &mut Vec<Signature>,
1795    index: &SyntaxIndexState,
1796) -> Option<usize> {
1797    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1798        return Some(*signature_index);
1799    }
1800    let Some(parent_declaration) = index.declaration(declaration_id) else {
1801        log::error!("bug: missing parent declaration");
1802        return None;
1803    };
1804    let parent_index = parent_declaration.parent().and_then(|parent| {
1805        add_signature(parent, declaration_to_signature_index, signatures, index)
1806    });
1807    let (text, text_is_truncated) = parent_declaration.signature_text();
1808    let signature_index = signatures.len();
1809    signatures.push(Signature {
1810        text: text.into(),
1811        text_is_truncated,
1812        parent_index,
1813        range: parent_declaration.signature_line_range(),
1814    });
1815    declaration_to_signature_index.insert(declaration_id, signature_index);
1816    Some(signature_index)
1817}
1818
/// Key for [`EvalCache`] entries: the entry kind paired with a `u64`
/// discriminator (presumably a hash of the request input — see
/// `EvalCache::write`; TODO confirm).
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
1821
/// Which stage of a request an eval cache entry belongs to; used as the first
/// component of [`EvalCacheKey`].
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
1829
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the lowercase name of the entry kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(name)
    }
}
1840
/// Cache keyed by [`EvalCacheKey`], used during eval runs (NOTE(review):
/// presumably to avoid re-issuing identical requests — confirm with callers).
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached value for `key`, if one is present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` for `key`; `input` is the raw request text it was
    /// produced from.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
1846
// Tests drive Zeta against a fake HTTP endpoint (see `init_test`): each model
// request arrives on a channel together with a oneshot responder, letting the
// test inspect the request and hand back a canned completion.
#[cfg(test)]
mod tests {
    use std::{path::Path, sync::Arc};

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::OffsetRangeExt as _;
    use open_ai::Usage;
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    // Exercises `current_prediction_for_buffer` across the prediction
    // lifecycle: a prediction in the cursor's own buffer (`Local`), a context
    // refresh answered with a search tool call, and a prediction targeting a
    // different file (`Jump` from buffer 1, `Local` once buffer 2 is open).
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh: the model answers with a search tool call rather
        // than a completion.
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        prediction_task.await.unwrap();
        // Viewed from buffer 1 the prediction is a jump to 2.txt...
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // ...but from buffer 2 itself it is a local edit.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    // Happy path: a single prediction request whose unified-diff response is
    // parsed into one edit at the cursor position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    // Verifies that edits made to a registered buffer are reported to the
    // model as a diff in the prompt of the next prediction request.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The edit above should appear in the prompt as a unified diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    // Skipped until we start including diagnostics in prompt
    // #[gpui::test]
    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
    //     let (zeta, mut req_rx) = init_test(cx);
    //     let fs = FakeFs::new(cx.executor());
    //     fs.insert_tree(
    //         "/root",
    //         json!({
    //             "foo.md": "Hello!\nBye"
    //         }),
    //     )
    //     .await;
    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
    //     let diagnostic = lsp::Diagnostic {
    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
    //         ..Default::default()
    //     };

    //     project.update(cx, |project, cx| {
    //         project.lsp_store().update(cx, |lsp_store, cx| {
    //             // Create some diagnostics
    //             lsp_store
    //                 .update_diagnostics(
    //                     LanguageServerId(0),
    //                     lsp::PublishDiagnosticsParams {
    //                         uri: path_to_buffer_uri.clone(),
    //                         diagnostics: vec![diagnostic],
    //                         version: None,
    //                     },
    //                     None,
    //                     language::DiagnosticSourceKind::Pushed,
    //                     &[],
    //                     cx,
    //                 )
    //                 .unwrap();
    //         });
    //     });

    //     let buffer = project
    //         .update(cx, |project, cx| {
    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
    //             project.open_buffer(path, cx)
    //         })
    //         .await
    //         .unwrap();

    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    //     let position = snapshot.anchor_before(language::Point::new(0, 0));

    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
    //         zeta.request_prediction(&project, &buffer, position, cx)
    //     });

    //     let (request, _respond_tx) = req_rx.next().await.unwrap();

    //     assert_eq!(request.diagnostic_groups.len(), 1);
    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
    //         .unwrap();
    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
    //     assert_eq!(
    //         value,
    //         json!({
    //             "entries": [{
    //                 "range": {
    //                     "start": 8,
    //                     "end": 10
    //                 },
    //                 "diagnostic": {
    //                     "source": null,
    //                     "code": null,
    //                     "code_description": null,
    //                     "severity": 1,
    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
    //                     "markdown": null,
    //                     "group_id": 0,
    //                     "is_primary": true,
    //                     "is_disk_based": false,
    //                     "is_unnecessary": false,
    //                     "source_kind": "Pushed",
    //                     "data": null,
    //                     "underline": true
    //                 }
    //             }],
    //             "primary_ix": 0
    //         })
    //     );
    // }

    // Wraps `text` in a minimal single-choice assistant completion response.
    fn model_response(text: &str) -> open_ai::Response {
        open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
                    tool_calls: vec![],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }

    // Extracts the prompt text from a request, asserting it consists of
    // exactly one plain-text user message.
    fn prompt_from_request(request: &open_ai::Request) -> &str {
        assert_eq!(request.messages.len(), 1);
        let open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(content),
            ..
        } = &request.messages[0]
        else {
            panic!(
                "Request does not have single user message of type Plain. {:#?}",
                request
            );
        };
        content
    }

    // Builds a Zeta entity backed by a fake HTTP client. Token requests are
    // answered inline; `/predict_edits/raw` requests are forwarded on the
    // returned channel along with a oneshot sender for the test's response.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}