// zeta2.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  10use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  11use collections::HashMap;
  12use edit_prediction_context::{
  13    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  14    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  15    SyntaxIndex, SyntaxIndexState,
  16};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::AsyncReadExt as _;
  19use futures::channel::{mpsc, oneshot};
  20use gpui::http_client::{AsyncBody, Method};
  21use gpui::{
  22    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  23    http_client, prelude::*,
  24};
  25use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  26use language::{BufferSnapshot, OffsetRangeExt};
  27use language_model::{LlmApiToken, RefreshLlmTokenListener};
  28use open_ai::FunctionDefinition;
  29use project::Project;
  30use release_channel::AppVersion;
  31use serde::de::DeserializeOwned;
  32use std::collections::{VecDeque, hash_map};
  33
  34use std::env;
  35use std::ops::Range;
  36use std::path::Path;
  37use std::str::FromStr as _;
  38use std::sync::{Arc, LazyLock};
  39use std::time::{Duration, Instant};
  40use thiserror::Error;
  41use util::rel_path::RelPathBuf;
  42use util::{LogErrorFuture, TryFutureExt};
  43use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  44
  45pub mod merge_excerpts;
  46mod prediction;
  47mod provider;
  48pub mod retrieval_search;
  49pub mod udiff;
  50mod xml_edits;
  51
  52use crate::merge_excerpts::merge_excerpts;
  53use crate::prediction::EditPrediction;
  54pub use crate::prediction::EditPredictionId;
  55pub use provider::ZetaEditPredictionProvider;
  56
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Default sizing for the excerpt captured around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim to spend about half the excerpt budget before the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context mode: agentic retrieval with the default excerpt sizing.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for syntax-index based context gathering.
// NOTE(review): `max_retrieved_declarations` is 0 here, so by default no
// declarations are retrieved in syntax mode — confirm that's intended.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default top-level [`ZetaOptions`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
  91
/// True when the `ZED_ZETA2_OLLAMA` env var is set to a non-empty value:
/// routes model traffic to a local Ollama server instead of the hosted models.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
/// Model id used for context retrieval; override via `ZED_ZETA2_CONTEXT_MODEL`.
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
/// Model id used for edit predictions; override via `ZED_ZETA2_MODEL`
/// (`zeta2-exp` is an alias for the fine-tuned model).
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
/// Optional prediction-endpoint override (`ZED_PREDICT_EDITS_URL`); falls back
/// to the local Ollama chat-completions URL when `USE_OLLAMA` is set.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
 119
/// Feature flag gating the zeta2 edit-prediction provider.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Staff do not get the flag implicitly; it must be enabled explicitly.
    fn enabled_for_staff() -> bool {
        false
    }
}
 129
/// Newtype holding the app-global [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 134
/// Coordinator for zeta2 edit predictions: tracks per-project state and
/// buffer-change history, and issues prediction requests to the service.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Keeps `llm_token` refreshed for as long as this entity lives.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server reports this Zed version
    // is too old (see MINIMUM_REQUIRED_VERSION_HEADER_NAME) — the write site
    // is outside this chunk; confirm there.
    update_required: bool,
    /// When present, debug events are streamed to this channel.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    // `EvalCache` is declared elsewhere in this crate; verify its contract at
    // the definition site.
    eval_cache: Option<Arc<dyn EvalCache>>,
}
 147
/// Tunable options for zeta2 prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    /// Upper bound on the prompt size, in bytes.
    pub max_prompt_bytes: usize,
    /// Upper bound on diagnostics included in a request, in bytes.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    /// Consecutive edits within this interval are coalesced into one event.
    pub buffer_change_grouping_interval: Duration,
}

/// How context for a prediction is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Model-driven retrieval using search queries.
    Agentic(AgenticContextOptions),
    /// Context drawn from the local [`SyntaxIndex`].
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
 168
 169impl ContextMode {
 170    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 171        match self {
 172            ContextMode::Agentic(options) => &options.excerpt,
 173            ContextMode::Syntax(options) => &options.excerpt,
 174        }
 175    }
 176}
 177
/// Events emitted to the debug channel opened by [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// The prompt sent to the context-retrieval model.
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering context before the request was sent.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// Locally rendered prompt, or an error message if rendering failed.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error) plus the time it took.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 218
/// Per-project state tracked by [`Zeta`].
struct ZetaProject {
    /// Present only when running in [`ContextMode::Syntax`].
    syntax_index: Option<Entity<SyntaxIndex>>,
    /// Recent buffer-change events, oldest first, capped at `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// Context ranges gathered per buffer; `None` until retrieval has run.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}

/// The latest prediction, together with the buffer that requested it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
 235
 236impl CurrentEditPrediction {
 237    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 238        let Some(new_edits) = self
 239            .prediction
 240            .interpolate(&self.prediction.buffer.read(cx))
 241        else {
 242            return false;
 243        };
 244
 245        if self.prediction.buffer != old_prediction.prediction.buffer {
 246            return true;
 247        }
 248
 249        let Some(old_edits) = old_prediction
 250            .prediction
 251            .interpolate(&old_prediction.prediction.buffer.read(cx))
 252        else {
 253            return true;
 254        };
 255
 256        // This reduces the occurrence of UI thrash from replacing edits
 257        //
 258        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 259        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 260            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 261            && old_edits.len() == 1
 262            && new_edits.len() == 1
 263        {
 264            let (old_range, old_text) = &old_edits[0];
 265            let (new_range, new_text) = &new_edits[0];
 266            new_range == old_range && new_text.starts_with(old_text.as_ref())
 267        } else {
 268            true
 269        }
 270    }
 271}
 272
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer that this one can jump to.
    Jump { prediction: &'a EditPrediction },
}

/// Snapshot and subscriptions kept for each tracked buffer.
struct RegisteredBuffer {
    /// Last snapshot reported; diffed against the next change.
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// A recorded change to a tracked buffer.
#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
 293
impl Event {
    /// Converts this event into the wire format used in prediction requests,
    /// returning `None` when there is nothing worth reporting.
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                // Only report the old path when the buffer was renamed.
                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2] move to bg?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                // NOTE(review): when the path is unchanged, `old_path` is
                // `None` while `path` is `Some`, so `path == old_path` only
                // holds for pathless buffers — a named buffer with an empty
                // diff still produces an event. Confirm that's intended.
                if path == old_path && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        //todo: Actually detect if this edit was predicted or not
                        predicted: false,
                    })
                }
            }
        }
    }
}
 331
 332impl Zeta {
 333    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 334        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 335    }
 336
 337    pub fn global(
 338        client: &Arc<Client>,
 339        user_store: &Entity<UserStore>,
 340        cx: &mut App,
 341    ) -> Entity<Self> {
 342        cx.try_global::<ZetaGlobal>()
 343            .map(|global| global.0.clone())
 344            .unwrap_or_else(|| {
 345                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 346                cx.set_global(ZetaGlobal(zeta.clone()));
 347                zeta
 348            })
 349    }
 350
 351    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 352        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 353
 354        Self {
 355            projects: HashMap::default(),
 356            client,
 357            user_store,
 358            options: DEFAULT_OPTIONS,
 359            llm_token: LlmApiToken::default(),
 360            _llm_token_subscription: cx.subscribe(
 361                &refresh_llm_token_listener,
 362                |this, _listener, _event, cx| {
 363                    let client = this.client.clone();
 364                    let llm_token = this.llm_token.clone();
 365                    cx.spawn(async move |_this, _cx| {
 366                        llm_token.refresh(&client).await?;
 367                        anyhow::Ok(())
 368                    })
 369                    .detach_and_log_err(cx);
 370                },
 371            ),
 372            update_required: false,
 373            debug_tx: None,
 374            #[cfg(feature = "eval-support")]
 375            eval_cache: None,
 376        }
 377    }
 378
    /// Installs a cache for eval runs (eval-support builds only).
    // NOTE(review): `EvalCache` is declared elsewhere in this crate; verify
    // its contract at the definition site.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }

    /// Opens a channel receiving [`ZetaDebugInfo`] events, replacing any
    /// previously registered debug channel.
    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }
 389
    /// Returns the currently configured options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the options used for subsequent predictions.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 397
 398    pub fn clear_history(&mut self) {
 399        for zeta_project in self.projects.values_mut() {
 400            zeta_project.events.clear();
 401        }
 402    }
 403
 404    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
 405        self.projects
 406            .get(&project.entity_id())
 407            .map(|project| project.events.iter())
 408            .into_iter()
 409            .flatten()
 410    }
 411
 412    pub fn context_for_project(
 413        &self,
 414        project: &Entity<Project>,
 415    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 416        self.projects
 417            .get(&project.entity_id())
 418            .and_then(|project| {
 419                Some(
 420                    project
 421                        .context
 422                        .as_ref()?
 423                        .iter()
 424                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 425                )
 426            })
 427            .into_iter()
 428            .flatten()
 429    }
 430
    /// Returns the user's edit-prediction usage, if known.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }

    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking `buffer` within `project`, creating project state on
    /// demand.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 448
 449    fn get_or_init_zeta_project(
 450        &mut self,
 451        project: &Entity<Project>,
 452        cx: &mut App,
 453    ) -> &mut ZetaProject {
 454        self.projects
 455            .entry(project.entity_id())
 456            .or_insert_with(|| ZetaProject {
 457                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
 458                    Some(cx.new(|cx| {
 459                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 460                    }))
 461                } else {
 462                    None
 463                },
 464                events: VecDeque::new(),
 465                registered_buffers: HashMap::default(),
 466                current_prediction: None,
 467                context: None,
 468                refresh_context_task: None,
 469                refresh_context_debounce_task: None,
 470                refresh_context_timestamp: None,
 471            })
 472    }
 473
 474    fn register_buffer_impl<'a>(
 475        zeta_project: &'a mut ZetaProject,
 476        buffer: &Entity<Buffer>,
 477        project: &Entity<Project>,
 478        cx: &mut Context<Self>,
 479    ) -> &'a mut RegisteredBuffer {
 480        let buffer_id = buffer.entity_id();
 481        match zeta_project.registered_buffers.entry(buffer_id) {
 482            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 483            hash_map::Entry::Vacant(entry) => {
 484                let snapshot = buffer.read(cx).snapshot();
 485                let project_entity_id = project.entity_id();
 486                entry.insert(RegisteredBuffer {
 487                    snapshot,
 488                    _subscriptions: [
 489                        cx.subscribe(buffer, {
 490                            let project = project.downgrade();
 491                            move |this, buffer, event, cx| {
 492                                if let language::BufferEvent::Edited = event
 493                                    && let Some(project) = project.upgrade()
 494                                {
 495                                    this.report_changes_for_buffer(&buffer, &project, cx);
 496                                }
 497                            }
 498                        }),
 499                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 500                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 501                            else {
 502                                return;
 503                            };
 504                            zeta_project.registered_buffers.remove(&buffer_id);
 505                        }),
 506                    ],
 507                })
 508            }
 509        }
 510    }
 511
 512    fn report_changes_for_buffer(
 513        &mut self,
 514        buffer: &Entity<Buffer>,
 515        project: &Entity<Project>,
 516        cx: &mut Context<Self>,
 517    ) -> BufferSnapshot {
 518        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
 519        let zeta_project = self.get_or_init_zeta_project(project, cx);
 520        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 521
 522        let new_snapshot = buffer.read(cx).snapshot();
 523        if new_snapshot.version != registered_buffer.snapshot.version {
 524            let old_snapshot =
 525                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 526            Self::push_event(
 527                zeta_project,
 528                buffer_change_grouping_interval,
 529                Event::BufferChange {
 530                    old_snapshot,
 531                    new_snapshot: new_snapshot.clone(),
 532                    timestamp: Instant::now(),
 533                },
 534            );
 535        }
 536
 537        new_snapshot
 538    }
 539
 540    fn push_event(
 541        zeta_project: &mut ZetaProject,
 542        buffer_change_grouping_interval: Duration,
 543        event: Event,
 544    ) {
 545        let events = &mut zeta_project.events;
 546
 547        if buffer_change_grouping_interval > Duration::ZERO
 548            && let Some(Event::BufferChange {
 549                new_snapshot: last_new_snapshot,
 550                timestamp: last_timestamp,
 551                ..
 552            }) = events.back_mut()
 553        {
 554            // Coalesce edits for the same buffer when they happen one after the other.
 555            let Event::BufferChange {
 556                old_snapshot,
 557                new_snapshot,
 558                timestamp,
 559            } = &event;
 560
 561            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
 562                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 563                && old_snapshot.version == last_new_snapshot.version
 564            {
 565                *last_new_snapshot = new_snapshot.clone();
 566                *last_timestamp = *timestamp;
 567                return;
 568            }
 569        }
 570
 571        if events.len() >= MAX_EVENT_COUNT {
 572            // These are halved instead of popping to improve prompt caching.
 573            events.drain(..MAX_EVENT_COUNT / 2);
 574        }
 575
 576        events.push_back(event);
 577    }
 578
 579    fn current_prediction_for_buffer(
 580        &self,
 581        buffer: &Entity<Buffer>,
 582        project: &Entity<Project>,
 583        cx: &App,
 584    ) -> Option<BufferEditPrediction<'_>> {
 585        let project_state = self.projects.get(&project.entity_id())?;
 586
 587        let CurrentEditPrediction {
 588            requested_by_buffer_id,
 589            prediction,
 590        } = project_state.current_prediction.as_ref()?;
 591
 592        if prediction.targets_buffer(buffer.read(cx)) {
 593            Some(BufferEditPrediction::Local { prediction })
 594        } else if *requested_by_buffer_id == buffer.entity_id() {
 595            Some(BufferEditPrediction::Jump { prediction })
 596        } else {
 597            None
 598        }
 599    }
 600
    /// Consumes the project's current prediction and reports its acceptance
    /// to the server (fire-and-forget; failures are logged).
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Taking the prediction clears it, so it cannot be accepted twice.
        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the endpoint (for dev/testing).
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // Surfaces update-required / usage info from the response.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 645
 646    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 647        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 648            project_state.current_prediction.take();
 649        };
 650    }
 651
 652    pub fn refresh_prediction(
 653        &mut self,
 654        project: &Entity<Project>,
 655        buffer: &Entity<Buffer>,
 656        position: language::Anchor,
 657        cx: &mut Context<Self>,
 658    ) -> Task<Result<()>> {
 659        let request_task = self.request_prediction(project, buffer, position, cx);
 660        let buffer = buffer.clone();
 661        let project = project.clone();
 662
 663        cx.spawn(async move |this, cx| {
 664            if let Some(prediction) = request_task.await? {
 665                this.update(cx, |this, cx| {
 666                    let project_state = this
 667                        .projects
 668                        .get_mut(&project.entity_id())
 669                        .context("Project not found")?;
 670
 671                    let new_prediction = CurrentEditPrediction {
 672                        requested_by_buffer_id: buffer.entity_id(),
 673                        prediction: prediction,
 674                    };
 675
 676                    if project_state
 677                        .current_prediction
 678                        .as_ref()
 679                        .is_none_or(|old_prediction| {
 680                            new_prediction.should_replace_prediction(&old_prediction, cx)
 681                        })
 682                    {
 683                        project_state.current_prediction = Some(new_prediction);
 684                    }
 685                    anyhow::Ok(())
 686                })??;
 687            }
 688            Ok(())
 689        })
 690    }
 691
 692    pub fn request_prediction(
 693        &mut self,
 694        project: &Entity<Project>,
 695        active_buffer: &Entity<Buffer>,
 696        position: language::Anchor,
 697        cx: &mut Context<Self>,
 698    ) -> Task<Result<Option<EditPrediction>>> {
 699        let project_state = self.projects.get(&project.entity_id());
 700
 701        let index_state = project_state.and_then(|state| {
 702            state
 703                .syntax_index
 704                .as_ref()
 705                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
 706        });
 707        let options = self.options.clone();
 708        let active_snapshot = active_buffer.read(cx).snapshot();
 709        let Some(excerpt_path) = active_snapshot
 710            .file()
 711            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
 712        else {
 713            return Task::ready(Err(anyhow!("No file path for excerpt")));
 714        };
 715        let client = self.client.clone();
 716        let llm_token = self.llm_token.clone();
 717        let app_version = AppVersion::global(cx);
 718        let worktree_snapshots = project
 719            .read(cx)
 720            .worktrees(cx)
 721            .map(|worktree| worktree.read(cx).snapshot())
 722            .collect::<Vec<_>>();
 723        let debug_tx = self.debug_tx.clone();
 724
 725        let events = project_state
 726            .map(|state| {
 727                state
 728                    .events
 729                    .iter()
 730                    .filter_map(|event| event.to_request_event(cx))
 731                    .collect::<Vec<_>>()
 732            })
 733            .unwrap_or_default();
 734
 735        let diagnostics = active_snapshot.diagnostic_sets().clone();
 736
 737        let parent_abs_path =
 738            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
 739                let mut path = f.worktree.read(cx).absolutize(&f.path);
 740                if path.pop() { Some(path) } else { None }
 741            });
 742
 743        // TODO data collection
 744        let can_collect_data = cx.is_staff();
 745
 746        let empty_context_files = HashMap::default();
 747        let context_files = project_state
 748            .and_then(|project_state| project_state.context.as_ref())
 749            .unwrap_or(&empty_context_files);
 750
 751        #[cfg(feature = "eval-support")]
 752        let parsed_fut = futures::future::join_all(
 753            context_files
 754                .keys()
 755                .map(|buffer| buffer.read(cx).parsing_idle()),
 756        );
 757
 758        let mut included_files = context_files
 759            .iter()
 760            .filter_map(|(buffer_entity, ranges)| {
 761                let buffer = buffer_entity.read(cx);
 762                Some((
 763                    buffer_entity.clone(),
 764                    buffer.snapshot(),
 765                    buffer.file()?.full_path(cx).into(),
 766                    ranges.clone(),
 767                ))
 768            })
 769            .collect::<Vec<_>>();
 770
 771        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
 772            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
 773        });
 774
 775        #[cfg(feature = "eval-support")]
 776        let eval_cache = self.eval_cache.clone();
 777
 778        let request_task = cx.background_spawn({
 779            let active_buffer = active_buffer.clone();
 780            async move {
 781                #[cfg(feature = "eval-support")]
 782                parsed_fut.await;
 783
 784                let index_state = if let Some(index_state) = index_state {
 785                    Some(index_state.lock_owned().await)
 786                } else {
 787                    None
 788                };
 789
 790                let cursor_offset = position.to_offset(&active_snapshot);
 791                let cursor_point = cursor_offset.to_point(&active_snapshot);
 792
 793                let before_retrieval = chrono::Utc::now();
 794
 795                let (diagnostic_groups, diagnostic_groups_truncated) =
 796                    Self::gather_nearby_diagnostics(
 797                        cursor_offset,
 798                        &diagnostics,
 799                        &active_snapshot,
 800                        options.max_diagnostic_bytes,
 801                    );
 802
 803                let cloud_request = match options.context {
 804                    ContextMode::Agentic(context_options) => {
 805                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
 806                            cursor_point,
 807                            &active_snapshot,
 808                            &context_options.excerpt,
 809                            index_state.as_deref(),
 810                        ) else {
 811                            return Ok((None, None));
 812                        };
 813
 814                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
 815                            ..active_snapshot.anchor_before(excerpt.range.end);
 816
 817                        if let Some(buffer_ix) =
 818                            included_files.iter().position(|(_, snapshot, _, _)| {
 819                                snapshot.remote_id() == active_snapshot.remote_id()
 820                            })
 821                        {
 822                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
 823                            let range_ix = ranges
 824                                .binary_search_by(|probe| {
 825                                    probe
 826                                        .start
 827                                        .cmp(&excerpt_anchor_range.start, buffer)
 828                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
 829                                })
 830                                .unwrap_or_else(|ix| ix);
 831
 832                            ranges.insert(range_ix, excerpt_anchor_range);
 833                            let last_ix = included_files.len() - 1;
 834                            included_files.swap(buffer_ix, last_ix);
 835                        } else {
 836                            included_files.push((
 837                                active_buffer.clone(),
 838                                active_snapshot.clone(),
 839                                excerpt_path.clone(),
 840                                vec![excerpt_anchor_range],
 841                            ));
 842                        }
 843
 844                        let included_files = included_files
 845                            .iter()
 846                            .map(|(_, snapshot, path, ranges)| {
 847                                let excerpts = merge_excerpts(
 848                                    &snapshot,
 849                                    ranges.iter().map(|range| {
 850                                        let point_range = range.to_point(&snapshot);
 851                                        Line(point_range.start.row)..Line(point_range.end.row)
 852                                    }),
 853                                );
 854                                predict_edits_v3::IncludedFile {
 855                                    path: path.clone(),
 856                                    max_row: Line(snapshot.max_point().row),
 857                                    excerpts,
 858                                }
 859                            })
 860                            .collect::<Vec<_>>();
 861
 862                        predict_edits_v3::PredictEditsRequest {
 863                            excerpt_path,
 864                            excerpt: String::new(),
 865                            excerpt_line_range: Line(0)..Line(0),
 866                            excerpt_range: 0..0,
 867                            cursor_point: predict_edits_v3::Point {
 868                                line: predict_edits_v3::Line(cursor_point.row),
 869                                column: cursor_point.column,
 870                            },
 871                            included_files,
 872                            referenced_declarations: vec![],
 873                            events,
 874                            can_collect_data,
 875                            diagnostic_groups,
 876                            diagnostic_groups_truncated,
 877                            debug_info: debug_tx.is_some(),
 878                            prompt_max_bytes: Some(options.max_prompt_bytes),
 879                            prompt_format: options.prompt_format,
 880                            // TODO [zeta2]
 881                            signatures: vec![],
 882                            excerpt_parent: None,
 883                            git_info: None,
 884                        }
 885                    }
 886                    ContextMode::Syntax(context_options) => {
 887                        let Some(context) = EditPredictionContext::gather_context(
 888                            cursor_point,
 889                            &active_snapshot,
 890                            parent_abs_path.as_deref(),
 891                            &context_options,
 892                            index_state.as_deref(),
 893                        ) else {
 894                            return Ok((None, None));
 895                        };
 896
 897                        make_syntax_context_cloud_request(
 898                            excerpt_path,
 899                            context,
 900                            events,
 901                            can_collect_data,
 902                            diagnostic_groups,
 903                            diagnostic_groups_truncated,
 904                            None,
 905                            debug_tx.is_some(),
 906                            &worktree_snapshots,
 907                            index_state.as_deref(),
 908                            Some(options.max_prompt_bytes),
 909                            options.prompt_format,
 910                        )
 911                    }
 912                };
 913
 914                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
 915
 916                let retrieval_time = chrono::Utc::now() - before_retrieval;
 917
 918                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
 919                    let (response_tx, response_rx) = oneshot::channel();
 920
 921                    debug_tx
 922                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
 923                            ZetaEditPredictionDebugInfo {
 924                                request: cloud_request.clone(),
 925                                retrieval_time,
 926                                buffer: active_buffer.downgrade(),
 927                                local_prompt: match prompt_result.as_ref() {
 928                                    Ok((prompt, _)) => Ok(prompt.clone()),
 929                                    Err(err) => Err(err.to_string()),
 930                                },
 931                                position,
 932                                response_rx,
 933                            },
 934                        ))
 935                        .ok();
 936                    Some(response_tx)
 937                } else {
 938                    None
 939                };
 940
 941                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 942                    if let Some(debug_response_tx) = debug_response_tx {
 943                        debug_response_tx
 944                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
 945                            .ok();
 946                    }
 947                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 948                }
 949
 950                let (prompt, _) = prompt_result?;
 951                let request = open_ai::Request {
 952                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
 953                    messages: vec![open_ai::RequestMessage::User {
 954                        content: open_ai::MessageContent::Plain(prompt),
 955                    }],
 956                    stream: false,
 957                    max_completion_tokens: None,
 958                    stop: Default::default(),
 959                    temperature: 0.7,
 960                    tool_choice: None,
 961                    parallel_tool_calls: None,
 962                    tools: vec![],
 963                    prompt_cache_key: None,
 964                    reasoning_effort: None,
 965                };
 966
 967                log::trace!("Sending edit prediction request");
 968
 969                let before_request = chrono::Utc::now();
 970                let response = Self::send_raw_llm_request(
 971                    request,
 972                    client,
 973                    llm_token,
 974                    app_version,
 975                    #[cfg(feature = "eval-support")]
 976                    eval_cache,
 977                    #[cfg(feature = "eval-support")]
 978                    EvalCacheEntryKind::Prediction,
 979                )
 980                .await;
 981                let request_time = chrono::Utc::now() - before_request;
 982
 983                log::trace!("Got edit prediction response");
 984
 985                if let Some(debug_response_tx) = debug_response_tx {
 986                    debug_response_tx
 987                        .send((
 988                            response
 989                                .as_ref()
 990                                .map_err(|err| err.to_string())
 991                                .map(|response| response.0.clone()),
 992                            request_time,
 993                        ))
 994                        .ok();
 995                }
 996
 997                let (res, usage) = response?;
 998                let request_id = EditPredictionId(res.id.clone().into());
 999                let Some(mut output_text) = text_from_response(res) else {
1000                    return Ok((None, usage));
1001                };
1002
1003                if output_text.contains(CURSOR_MARKER) {
1004                    log::trace!("Stripping out {CURSOR_MARKER} from response");
1005                    output_text = output_text.replace(CURSOR_MARKER, "");
1006                }
1007
1008                let get_buffer_from_context = |path: &Path| {
1009                    included_files
1010                        .iter()
1011                        .find_map(|(_, buffer, probe_path, ranges)| {
1012                            if probe_path.as_ref() == path {
1013                                Some((buffer, ranges.as_slice()))
1014                            } else {
1015                                None
1016                            }
1017                        })
1018                };
1019
1020                let (edited_buffer_snapshot, edits) = match options.prompt_format {
1021                    PromptFormat::NumLinesUniDiff => {
1022                        // TODO: Implement parsing of multi-file diffs
1023                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1024                    }
1025                    PromptFormat::Minimal => {
1026                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
1027                            let edits = vec![];
1028                            (&active_snapshot, edits)
1029                        } else {
1030                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1031                        }
1032                    }
1033                    PromptFormat::OldTextNewText => {
1034                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
1035                            .await?
1036                    }
1037                    _ => {
1038                        bail!("unsupported prompt format {}", options.prompt_format)
1039                    }
1040                };
1041
1042                let edited_buffer = included_files
1043                    .iter()
1044                    .find_map(|(buffer, snapshot, _, _)| {
1045                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1046                            Some(buffer.clone())
1047                        } else {
1048                            None
1049                        }
1050                    })
1051                    .context("Failed to find buffer in included_buffers")?;
1052
1053                anyhow::Ok((
1054                    Some((
1055                        request_id,
1056                        edited_buffer,
1057                        edited_buffer_snapshot.clone(),
1058                        edits,
1059                    )),
1060                    usage,
1061                ))
1062            }
1063        });
1064
1065        cx.spawn({
1066            async move |this, cx| {
1067                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
1068                    Self::handle_api_response(&this, request_task.await, cx)?
1069                else {
1070                    return Ok(None);
1071                };
1072
1073                // TODO telemetry: duration, etc
1074                Ok(
1075                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
1076                        .await,
1077                )
1078            }
1079        })
1080    }
1081
1082    async fn send_raw_llm_request(
1083        request: open_ai::Request,
1084        client: Arc<Client>,
1085        llm_token: LlmApiToken,
1086        app_version: SemanticVersion,
1087        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1088        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
1089    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1090        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1091            http_client::Url::parse(&predict_edits_url)?
1092        } else {
1093            client
1094                .http_client()
1095                .build_zed_llm_url("/predict_edits/raw", &[])?
1096        };
1097
1098        #[cfg(feature = "eval-support")]
1099        let cache_key = if let Some(cache) = eval_cache {
1100            use collections::FxHasher;
1101            use std::hash::{Hash, Hasher};
1102
1103            let mut hasher = FxHasher::default();
1104            url.hash(&mut hasher);
1105            let request_str = serde_json::to_string_pretty(&request)?;
1106            request_str.hash(&mut hasher);
1107            let hash = hasher.finish();
1108
1109            let key = (eval_cache_kind, hash);
1110            if let Some(response_str) = cache.read(key) {
1111                return Ok((serde_json::from_str(&response_str)?, None));
1112            }
1113
1114            Some((cache, request_str, key))
1115        } else {
1116            None
1117        };
1118
1119        let (response, usage) = Self::send_api_request(
1120            |builder| {
1121                let req = builder
1122                    .uri(url.as_ref())
1123                    .body(serde_json::to_string(&request)?.into());
1124                Ok(req?)
1125            },
1126            client,
1127            llm_token,
1128            app_version,
1129        )
1130        .await?;
1131
1132        #[cfg(feature = "eval-support")]
1133        if let Some((cache, request, key)) = cache_key {
1134            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1135        }
1136
1137        Ok((response, usage))
1138    }
1139
1140    fn handle_api_response<T>(
1141        this: &WeakEntity<Self>,
1142        response: Result<(T, Option<EditPredictionUsage>)>,
1143        cx: &mut gpui::AsyncApp,
1144    ) -> Result<T> {
1145        match response {
1146            Ok((data, usage)) => {
1147                if let Some(usage) = usage {
1148                    this.update(cx, |this, cx| {
1149                        this.user_store.update(cx, |user_store, cx| {
1150                            user_store.update_edit_prediction_usage(usage, cx);
1151                        });
1152                    })
1153                    .ok();
1154                }
1155                Ok(data)
1156            }
1157            Err(err) => {
1158                if err.is::<ZedUpdateRequiredError>() {
1159                    cx.update(|cx| {
1160                        this.update(cx, |this, _cx| {
1161                            this.update_required = true;
1162                        })
1163                        .ok();
1164
1165                        let error_message: SharedString = err.to_string().into();
1166                        show_app_notification(
1167                            NotificationId::unique::<ZedUpdateRequiredError>(),
1168                            cx,
1169                            move |cx| {
1170                                cx.new(|cx| {
1171                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1172                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1173                                })
1174                            },
1175                        );
1176                    })
1177                    .ok();
1178                }
1179                Err(err)
1180            }
1181        }
1182    }
1183
    /// Issues an authenticated POST to the LLM service and deserializes the
    /// JSON response.
    ///
    /// `build` receives a request builder pre-populated with the content-type,
    /// bearer-token, and Zed-version headers, and must supply the URI and body.
    ///
    /// An expired-token response (signalled via `EXPIRED_LLM_TOKEN_HEADER_NAME`)
    /// is retried exactly once after refreshing the token.
    ///
    /// # Errors
    /// Fails with [`ZedUpdateRequiredError`] when the server reports a minimum
    /// required version newer than `app_version`, and with the status plus
    /// response body text for any other non-success status.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Enforce the server's minimum-supported-version header, when present,
            // before inspecting the status code.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Token expired: refresh it and retry once.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                // Non-retryable failure: report the status and body text.
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1247
    /// Gap since the last context refresh after which the next edit is treated
    /// as coming out of an idle period (triggering an immediate refresh).
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Pause to wait after a (non-idle) edit before refreshing context.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1250
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Only agentic mode maintains asynchronously-retrieved context.
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        // No recorded timestamp (first call) also counts as idle.
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // NOTE(review): overwriting the previous debounce task appears to rely
        // on task-drop cancellation, so only the latest edit's refresh fires —
        // confirm against gpui::Task semantics.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    // Coming out of idle: refresh immediately, no debounce.
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Hold on to the task so it runs to completion.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
1298
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    //
    // Pipeline: build a planning prompt from the cursor excerpt and recent
    // events → ask the model to emit search-tool calls → run those searches →
    // store the resulting excerpts as the project's prediction context.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        // Context retrieval only applies to agentic mode.
        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            // No usable excerpt around the cursor; nothing to refresh.
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the search-planning prompt synchronously so failures surface
        // before any task is spawned.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Schema and description for the search tool exposed to the model,
        // computed once per process.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect queries from every search tool call; unknown tool names
            // are logged and skipped rather than failing the refresh.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Store the results (or propagate the search error), and clear the
            // in-flight task handle either way.
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
1514
1515    pub fn set_context(
1516        &mut self,
1517        project: Entity<Project>,
1518        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
1519    ) {
1520        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
1521            zeta_project.context = Some(context);
1522        }
1523    }
1524
1525    fn gather_nearby_diagnostics(
1526        cursor_offset: usize,
1527        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1528        snapshot: &BufferSnapshot,
1529        max_diagnostics_bytes: usize,
1530    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1531        // TODO: Could make this more efficient
1532        let mut diagnostic_groups = Vec::new();
1533        for (language_server_id, diagnostics) in diagnostic_sets {
1534            let mut groups = Vec::new();
1535            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1536            diagnostic_groups.extend(
1537                groups
1538                    .into_iter()
1539                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1540            );
1541        }
1542
1543        // sort by proximity to cursor
1544        diagnostic_groups.sort_by_key(|group| {
1545            let range = &group.entries[group.primary_ix].range;
1546            if range.start >= cursor_offset {
1547                range.start - cursor_offset
1548            } else if cursor_offset >= range.end {
1549                cursor_offset - range.end
1550            } else {
1551                (cursor_offset - range.start).min(range.end - cursor_offset)
1552            }
1553        });
1554
1555        let mut results = Vec::new();
1556        let mut diagnostic_groups_truncated = false;
1557        let mut diagnostics_byte_count = 0;
1558        for group in diagnostic_groups {
1559            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1560            diagnostics_byte_count += raw_value.get().len();
1561            if diagnostics_byte_count > max_diagnostics_bytes {
1562                diagnostic_groups_truncated = true;
1563                break;
1564            }
1565            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1566        }
1567
1568        (results, diagnostic_groups_truncated)
1569    }
1570
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds a `PredictEditsRequest` for the zeta CLI from the buffer and
    /// cursor position, without sending it anywhere.
    ///
    /// Entity state (syntax index, options, buffer snapshot, worktree
    /// snapshots, parent path) is captured on the main thread first; the
    /// request itself is assembled on a background task. Returns an error if
    /// the buffer has no file path or excerpt selection fails. Panics on
    /// `ContextMode::Agentic` — not supported by the CLI yet.
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        // Clone the syntax index state handle (if any) so the background task
        // can lock it without going through the entity system.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, if it has one.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            // Hold the index lock for the whole request-building pass.
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
1648
1649    pub fn wait_for_initial_indexing(
1650        &mut self,
1651        project: &Entity<Project>,
1652        cx: &mut App,
1653    ) -> Task<Result<()>> {
1654        let zeta_project = self.get_or_init_zeta_project(project, cx);
1655        if let Some(syntax_index) = &zeta_project.syntax_index {
1656            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
1657        } else {
1658            Task::ready(Ok(()))
1659        }
1660    }
1661}
1662
1663pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1664    let choice = res.choices.pop()?;
1665    let output_text = match choice.message {
1666        open_ai::RequestMessage::Assistant {
1667            content: Some(open_ai::MessageContent::Plain(content)),
1668            ..
1669        } => content,
1670        open_ai::RequestMessage::Assistant {
1671            content: Some(open_ai::MessageContent::Multipart(mut content)),
1672            ..
1673        } => {
1674            if content.is_empty() {
1675                log::error!("No output from Baseten completion response");
1676                return None;
1677            }
1678
1679            match content.remove(0) {
1680                open_ai::MessagePart::Text { text } => text,
1681                open_ai::MessagePart::Image { .. } => {
1682                    log::error!("Expected text, got an image");
1683                    return None;
1684                }
1685            }
1686        }
1687        _ => {
1688            log::error!("Invalid response message: {:?}", choice.message);
1689            return None;
1690        }
1691    };
1692    Some(output_text)
1693}
1694
/// Error shown when the prediction backend requires a newer Zed build.
/// NOTE(review): presumably populated from the minimum-version response header
/// (`MINIMUM_REQUIRED_VERSION_HEADER_NAME`) — the parsing site is outside this
/// chunk; confirm there.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: SemanticVersion,
}
1702
1703fn make_syntax_context_cloud_request(
1704    excerpt_path: Arc<Path>,
1705    context: EditPredictionContext,
1706    events: Vec<predict_edits_v3::Event>,
1707    can_collect_data: bool,
1708    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1709    diagnostic_groups_truncated: bool,
1710    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1711    debug_info: bool,
1712    worktrees: &Vec<worktree::Snapshot>,
1713    index_state: Option<&SyntaxIndexState>,
1714    prompt_max_bytes: Option<usize>,
1715    prompt_format: PromptFormat,
1716) -> predict_edits_v3::PredictEditsRequest {
1717    let mut signatures = Vec::new();
1718    let mut declaration_to_signature_index = HashMap::default();
1719    let mut referenced_declarations = Vec::new();
1720
1721    for snippet in context.declarations {
1722        let project_entry_id = snippet.declaration.project_entry_id();
1723        let Some(path) = worktrees.iter().find_map(|worktree| {
1724            worktree.entry_for_id(project_entry_id).map(|entry| {
1725                let mut full_path = RelPathBuf::new();
1726                full_path.push(worktree.root_name());
1727                full_path.push(&entry.path);
1728                full_path
1729            })
1730        }) else {
1731            continue;
1732        };
1733
1734        let parent_index = index_state.and_then(|index_state| {
1735            snippet.declaration.parent().and_then(|parent| {
1736                add_signature(
1737                    parent,
1738                    &mut declaration_to_signature_index,
1739                    &mut signatures,
1740                    index_state,
1741                )
1742            })
1743        });
1744
1745        let (text, text_is_truncated) = snippet.declaration.item_text();
1746        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1747            path: path.as_std_path().into(),
1748            text: text.into(),
1749            range: snippet.declaration.item_line_range(),
1750            text_is_truncated,
1751            signature_range: snippet.declaration.signature_range_in_item_text(),
1752            parent_index,
1753            signature_score: snippet.score(DeclarationStyle::Signature),
1754            declaration_score: snippet.score(DeclarationStyle::Declaration),
1755            score_components: snippet.components,
1756        });
1757    }
1758
1759    let excerpt_parent = index_state.and_then(|index_state| {
1760        context
1761            .excerpt
1762            .parent_declarations
1763            .last()
1764            .and_then(|(parent, _)| {
1765                add_signature(
1766                    *parent,
1767                    &mut declaration_to_signature_index,
1768                    &mut signatures,
1769                    index_state,
1770                )
1771            })
1772    });
1773
1774    predict_edits_v3::PredictEditsRequest {
1775        excerpt_path,
1776        excerpt: context.excerpt_text.body,
1777        excerpt_line_range: context.excerpt.line_range,
1778        excerpt_range: context.excerpt.range,
1779        cursor_point: predict_edits_v3::Point {
1780            line: predict_edits_v3::Line(context.cursor_point.row),
1781            column: context.cursor_point.column,
1782        },
1783        referenced_declarations,
1784        included_files: vec![],
1785        signatures,
1786        excerpt_parent,
1787        events,
1788        can_collect_data,
1789        diagnostic_groups,
1790        diagnostic_groups_truncated,
1791        git_info,
1792        debug_info,
1793        prompt_max_bytes,
1794        prompt_format,
1795    }
1796}
1797
1798fn add_signature(
1799    declaration_id: DeclarationId,
1800    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1801    signatures: &mut Vec<Signature>,
1802    index: &SyntaxIndexState,
1803) -> Option<usize> {
1804    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1805        return Some(*signature_index);
1806    }
1807    let Some(parent_declaration) = index.declaration(declaration_id) else {
1808        log::error!("bug: missing parent declaration");
1809        return None;
1810    };
1811    let parent_index = parent_declaration.parent().and_then(|parent| {
1812        add_signature(parent, declaration_to_signature_index, signatures, index)
1813    });
1814    let (text, text_is_truncated) = parent_declaration.signature_text();
1815    let signature_index = signatures.len();
1816    signatures.push(Signature {
1817        text: text.into(),
1818        text_is_truncated,
1819        parent_index,
1820        range: parent_declaration.signature_line_range(),
1821    });
1822    declaration_to_signature_index.insert(declaration_id, signature_index);
1823    Some(signature_index)
1824}
1825
/// Key for a cached eval entry: the entry kind plus a `u64` discriminator
/// (presumably a hash of the request input — confirm at the call sites).
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
1828
/// The kinds of model interaction that can be cached during evals.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
1836
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Renders the kind as a lowercase identifier.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(label)
    }
}
1847
/// Storage backend for caching model responses during evals.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached value for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` under `key`; `input` is presumably the request text
    /// that produced it (kept for inspection) — confirm with implementors.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
1853
1854#[cfg(test)]
1855mod tests {
1856    use std::{path::Path, sync::Arc};
1857
1858    use client::UserStore;
1859    use clock::FakeSystemClock;
1860    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
1861    use futures::{
1862        AsyncReadExt, StreamExt,
1863        channel::{mpsc, oneshot},
1864    };
1865    use gpui::{
1866        Entity, TestAppContext,
1867        http_client::{FakeHttpClient, Response},
1868        prelude::*,
1869    };
1870    use indoc::indoc;
1871    use language::OffsetRangeExt as _;
1872    use open_ai::Usage;
1873    use pretty_assertions::{assert_eq, assert_matches};
1874    use project::{FakeFs, Project};
1875    use serde_json::json;
1876    use settings::SettingsStore;
1877    use util::path;
1878    use uuid::Uuid;
1879
1880    use crate::{BufferEditPrediction, Zeta};
1881
    /// End-to-end check of prediction state transitions: a local prediction,
    /// a context refresh driven by a tool call, and a cross-file prediction
    /// that surfaces as `Jump` in the source buffer but `Local` in the target.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // The fake HTTP layer forwards the model request here; answer it with
        // a unified diff touching the buffer we predicted for.
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();
        prediction_task.await.unwrap();

        // An edit in the same file should be reported as a local prediction.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Respond with a tool call asking to search `root/2.txt`, exercising
        // the retrieval/search code path rather than a plain-text completion.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Diff targets 2.txt even though the request came from buffer1.
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        prediction_task.await.unwrap();
        // From buffer1's perspective this is a jump into root/2.txt.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // The same prediction is local once viewed from the target buffer.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
2026
    /// Happy-path request: a single-file prediction whose returned diff is
    /// translated into one buffer edit at the cursor position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The diff rewrites "How" -> "How are you?", which should collapse to
        // a single insertion of " are you?" at (1, 3).
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
2090
    /// Verifies that edits made to a registered buffer show up in the prompt
    /// as a diff of recent events before the prediction request is sent.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Register first so the subsequent edit is tracked as an event.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The prompt should contain the tracked edit as a unified diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
2164
2165    // Skipped until we start including diagnostics in prompt
2166    // #[gpui::test]
2167    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
2168    //     let (zeta, mut req_rx) = init_test(cx);
2169    //     let fs = FakeFs::new(cx.executor());
2170    //     fs.insert_tree(
2171    //         "/root",
2172    //         json!({
2173    //             "foo.md": "Hello!\nBye"
2174    //         }),
2175    //     )
2176    //     .await;
2177    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2178
2179    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
2180    //     let diagnostic = lsp::Diagnostic {
2181    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
2182    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
2183    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
2184    //         ..Default::default()
2185    //     };
2186
2187    //     project.update(cx, |project, cx| {
2188    //         project.lsp_store().update(cx, |lsp_store, cx| {
2189    //             // Create some diagnostics
2190    //             lsp_store
2191    //                 .update_diagnostics(
2192    //                     LanguageServerId(0),
2193    //                     lsp::PublishDiagnosticsParams {
2194    //                         uri: path_to_buffer_uri.clone(),
2195    //                         diagnostics: vec![diagnostic],
2196    //                         version: None,
2197    //                     },
2198    //                     None,
2199    //                     language::DiagnosticSourceKind::Pushed,
2200    //                     &[],
2201    //                     cx,
2202    //                 )
2203    //                 .unwrap();
2204    //         });
2205    //     });
2206
2207    //     let buffer = project
2208    //         .update(cx, |project, cx| {
2209    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2210    //             project.open_buffer(path, cx)
2211    //         })
2212    //         .await
2213    //         .unwrap();
2214
2215    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2216    //     let position = snapshot.anchor_before(language::Point::new(0, 0));
2217
2218    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
2219    //         zeta.request_prediction(&project, &buffer, position, cx)
2220    //     });
2221
2222    //     let (request, _respond_tx) = req_rx.next().await.unwrap();
2223
2224    //     assert_eq!(request.diagnostic_groups.len(), 1);
2225    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
2226    //         .unwrap();
2227    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
2228    //     assert_eq!(
2229    //         value,
2230    //         json!({
2231    //             "entries": [{
2232    //                 "range": {
2233    //                     "start": 8,
2234    //                     "end": 10
2235    //                 },
2236    //                 "diagnostic": {
2237    //                     "source": null,
2238    //                     "code": null,
2239    //                     "code_description": null,
2240    //                     "severity": 1,
2241    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
2242    //                     "markdown": null,
2243    //                     "group_id": 0,
2244    //                     "is_primary": true,
2245    //                     "is_disk_based": false,
2246    //                     "is_unnecessary": false,
2247    //                     "source_kind": "Pushed",
2248    //                     "data": null,
2249    //                     "underline": true
2250    //                 }
2251    //             }],
2252    //             "primary_ix": 0
2253    //         })
2254    //     );
2255    // }
2256
2257    fn model_response(text: &str) -> open_ai::Response {
2258        open_ai::Response {
2259            id: Uuid::new_v4().to_string(),
2260            object: "response".into(),
2261            created: 0,
2262            model: "model".into(),
2263            choices: vec![open_ai::Choice {
2264                index: 0,
2265                message: open_ai::RequestMessage::Assistant {
2266                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
2267                    tool_calls: vec![],
2268                },
2269                finish_reason: None,
2270            }],
2271            usage: Usage {
2272                prompt_tokens: 0,
2273                completion_tokens: 0,
2274                total_tokens: 0,
2275            },
2276        }
2277    }
2278
2279    fn prompt_from_request(request: &open_ai::Request) -> &str {
2280        assert_eq!(request.messages.len(), 1);
2281        let open_ai::RequestMessage::User {
2282            content: open_ai::MessageContent::Plain(content),
2283            ..
2284        } = &request.messages[0]
2285        else {
2286            panic!(
2287                "Request does not have single user message of type Plain. {:#?}",
2288                request
2289            );
2290        };
2291        content
2292    }
2293
    /// Sets up a `Zeta` entity backed by a fake HTTP client.
    ///
    /// Requests to `/predict_edits/raw` are forwarded to the returned channel
    /// as `(request, responder)` pairs so tests can inspect each model request
    /// and choose its response; `/client/llm_tokens` is stubbed to always
    /// succeed. Any other path panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh: always hand back a dummy token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction endpoint: parse the body, hand it to
                            // the test, and block until the test responds.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
2348}