zeta2.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  10use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  11use collections::HashMap;
  12use edit_prediction_context::{
  13    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  14    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  15    SyntaxIndex, SyntaxIndexState,
  16};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::AsyncReadExt as _;
  19use futures::channel::{mpsc, oneshot};
  20use gpui::http_client::{AsyncBody, Method};
  21use gpui::{
  22    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  23    http_client, prelude::*,
  24};
  25use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  26use language::{BufferSnapshot, OffsetRangeExt};
  27use language_model::{LlmApiToken, RefreshLlmTokenListener};
  28use open_ai::FunctionDefinition;
  29use project::Project;
  30use release_channel::AppVersion;
  31use serde::de::DeserializeOwned;
  32use std::collections::{VecDeque, hash_map};
  33
  34use std::env;
  35use std::ops::Range;
  36use std::path::Path;
  37use std::str::FromStr as _;
  38use std::sync::{Arc, LazyLock};
  39use std::time::{Duration, Instant};
  40use thiserror::Error;
  41use util::rel_path::RelPathBuf;
  42use util::{LogErrorFuture, TryFutureExt};
  43use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  44
  45pub mod merge_excerpts;
  46mod prediction;
  47mod provider;
  48pub mod retrieval_search;
  49pub mod udiff;
  50mod xml_edits;
  51
  52use crate::merge_excerpts::merge_excerpts;
  53use crate::prediction::EditPrediction;
  54pub use crate::prediction::EditPredictionId;
  55pub use provider::ZetaEditPredictionProvider;
  56
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Default sizing for the excerpt of buffer text around the cursor that is
/// included in a prediction request.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim for roughly half of the excerpt's bytes to precede the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context-gathering strategy.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Default options for `ContextMode::Agentic`.
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Default options for `ContextMode::Syntax` (syntax-index based retrieval).
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default top-level options for `Zeta`.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};

/// When `ZED_ZETA2_OLLAMA` is set to a non-empty value, requests are routed to
/// a local Ollama server (see `PREDICT_EDITS_URL` below).
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
/// Model identifier to request; overridable via `ZED_ZETA2_MODEL`.
static MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
/// Prediction endpoint override: `ZED_PREDICT_EDITS_URL` wins, then the local
/// Ollama endpoint when enabled, otherwise `None` (use the default backend).
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
 110
/// Feature flag gating the zeta2 edit-prediction backend.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Not auto-enabled for staff; must be turned on explicitly.
    fn enabled_for_staff() -> bool {
        false
    }
}
 120
/// Global handle to the singleton `Zeta` entity (see `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 125
/// Singleton coordinating edit predictions across all projects.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Token for the LLM backend; refreshed via `_llm_token_subscription`.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the `Project` entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the backend reports this client is too
    // old (see MINIMUM_REQUIRED_VERSION_HEADER_NAME import) — confirm.
    update_required: bool,
    // When present, debug events are streamed to this channel (`debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
}
 138
/// Tunable parameters controlling context gathering and prompt construction.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    /// Buffer edits closer together than this are coalesced into one event.
    pub buffer_change_grouping_interval: Duration,
}
 148
/// Strategy used to gather context for a prediction request.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Model-driven retrieval via generated search queries.
    Agentic(AgenticContextOptions),
    /// Local retrieval backed by a `SyntaxIndex`.
    Syntax(EditPredictionContextOptions),
}
 154
/// Options for `ContextMode::Agentic`.
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    /// Sizing of the excerpt around the cursor.
    pub excerpt: EditPredictionExcerptOptions,
}
 159
 160impl ContextMode {
 161    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 162        match self {
 163            ContextMode::Agentic(options) => &options.excerpt,
 164            ContextMode::Syntax(options) => &options.excerpt,
 165        }
 166    }
 167}
 168
/// Events streamed to debug subscribers (see `Zeta::debug_info`).
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}
 177
/// Debug payload emitted when context retrieval begins.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// Prompt sent to the model to generate search queries.
    pub search_prompt: String,
}
 184
/// Debug payload for context-retrieval lifecycle events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}
 190
/// Debug payload emitted when an edit prediction is requested.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering context before issuing the request.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// Locally-rendered prompt, or an error string if rendering failed.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error string) and its duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}
 200
/// Debug payload listing the search queries the model generated.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 209
/// Per-project prediction state.
struct ZetaProject {
    // Present only when options use `ContextMode::Syntax`.
    syntax_index: Option<Entity<SyntaxIndex>>,
    // Recent buffer-change history, capped at `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    // Retrieved context ranges per buffer; `None` until retrieval has run.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}
 220
/// The most recent prediction for a project, plus which buffer requested it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}

impl CurrentEditPrediction {
    /// Whether `self` (a newly computed prediction) should replace
    /// `old_prediction` as the project's current prediction.
    ///
    /// A new prediction that no longer interpolates against its buffer's
    /// current content never replaces anything. Otherwise it replaces the old
    /// one, except in the single-edit same-buffer case where replacement is
    /// restricted to pure extensions of the old edit (to reduce UI thrash).
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // If the new prediction can't be applied to the buffer as it now
        // stands, keep the old one.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // A stale old prediction (no longer interpolates) is always replaced.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            // Replace only when the new edit targets the same range and merely
            // extends the old text; otherwise keep the old prediction.
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
 263
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// Requested from this buffer, but the edits target a different buffer.
    Jump { prediction: &'a EditPrediction },
}
 270
/// State for a buffer whose edits feed the event history.
struct RegisteredBuffer {
    // Last snapshot seen; diffed against new versions to produce events.
    snapshot: BufferSnapshot,
    // Edit + release subscriptions (see `register_buffer_impl`).
    _subscriptions: [gpui::Subscription; 2],
}
 275
/// A recorded buffer change, kept as prediction-request history.
#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
 284
 285impl Event {
 286    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
 287        match self {
 288            Event::BufferChange {
 289                old_snapshot,
 290                new_snapshot,
 291                ..
 292            } => {
 293                let path = new_snapshot.file().map(|f| f.full_path(cx));
 294
 295                let old_path = old_snapshot.file().and_then(|f| {
 296                    let old_path = f.full_path(cx);
 297                    if Some(&old_path) != path.as_ref() {
 298                        Some(old_path)
 299                    } else {
 300                        None
 301                    }
 302                });
 303
 304                // TODO [zeta2] move to bg?
 305                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 306
 307                if path == old_path && diff.is_empty() {
 308                    None
 309                } else {
 310                    Some(predict_edits_v3::Event::BufferChange {
 311                        old_path,
 312                        path,
 313                        diff,
 314                        //todo: Actually detect if this edit was predicted or not
 315                        predicted: false,
 316                    })
 317                }
 318            }
 319        }
 320    }
 321}
 322
 323impl Zeta {
 324    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 325        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 326    }
 327
 328    pub fn global(
 329        client: &Arc<Client>,
 330        user_store: &Entity<UserStore>,
 331        cx: &mut App,
 332    ) -> Entity<Self> {
 333        cx.try_global::<ZetaGlobal>()
 334            .map(|global| global.0.clone())
 335            .unwrap_or_else(|| {
 336                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 337                cx.set_global(ZetaGlobal(zeta.clone()));
 338                zeta
 339            })
 340    }
 341
    /// Creates a `Zeta`, wiring up automatic refresh of the LLM API token.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the token whenever the listener signals a refresh.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
        }
    }
 369
    /// Installs the cache used by eval-support builds.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 374
 375    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 376        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 377        self.debug_tx = Some(debug_watch_tx);
 378        debug_watch_rx
 379    }
 380
    /// Current prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 384
    /// Replaces the prediction options.
    ///
    /// NOTE(review): an existing project's `syntax_index` is only created when
    /// its `ZetaProject` is first inserted (`get_or_init_zeta_project`), so
    /// changing `context` mode here does not retroactively create or drop
    /// indices for already-registered projects — confirm this is intended.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 388
 389    pub fn clear_history(&mut self) {
 390        for zeta_project in self.projects.values_mut() {
 391            zeta_project.events.clear();
 392        }
 393    }
 394
 395    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
 396        self.projects
 397            .get(&project.entity_id())
 398            .map(|project| project.events.iter())
 399            .into_iter()
 400            .flatten()
 401    }
 402
 403    pub fn context_for_project(
 404        &self,
 405        project: &Entity<Project>,
 406    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 407        self.projects
 408            .get(&project.entity_id())
 409            .and_then(|project| {
 410                Some(
 411                    project
 412                        .context
 413                        .as_ref()?
 414                        .iter()
 415                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 416                )
 417            })
 418            .into_iter()
 419            .flatten()
 420    }
 421
    /// Edit-prediction usage stats from the user store, if available.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }
 425
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }
 429
    /// Starts tracking `buffer` within `project` so its edits are recorded as
    /// history events.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 439
 440    fn get_or_init_zeta_project(
 441        &mut self,
 442        project: &Entity<Project>,
 443        cx: &mut App,
 444    ) -> &mut ZetaProject {
 445        self.projects
 446            .entry(project.entity_id())
 447            .or_insert_with(|| ZetaProject {
 448                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
 449                    Some(cx.new(|cx| {
 450                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 451                    }))
 452                } else {
 453                    None
 454                },
 455                events: VecDeque::new(),
 456                registered_buffers: HashMap::default(),
 457                current_prediction: None,
 458                context: None,
 459                refresh_context_task: None,
 460                refresh_context_debounce_task: None,
 461                refresh_context_timestamp: None,
 462            })
 463    }
 464
    /// Returns the `RegisteredBuffer` for `buffer`, registering it on first
    /// use.
    ///
    /// Registration takes an initial snapshot and installs two subscriptions:
    /// one that records change events on edit, and one that removes the
    /// registration when the buffer entity is released.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Hold the project weakly so the subscription
                            // doesn't keep it alive.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 502
    /// Records a change event for `buffer` if its content has changed since
    /// the snapshot we last stored, and returns the latest snapshot.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Only push an event when the buffer version has advanced.
        if new_snapshot.version != registered_buffer.snapshot.version {
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                buffer_change_grouping_interval,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }
 530
    /// Appends `event` to the project's history, coalescing consecutive edits
    /// of the same buffer that occur within `buffer_change_grouping_interval`.
    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            // Merge only when this event directly continues the previous one:
            // same buffer, and its old version equals the last new version.
            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }
 569
 570    fn current_prediction_for_buffer(
 571        &self,
 572        buffer: &Entity<Buffer>,
 573        project: &Entity<Project>,
 574        cx: &App,
 575    ) -> Option<BufferEditPrediction<'_>> {
 576        let project_state = self.projects.get(&project.entity_id())?;
 577
 578        let CurrentEditPrediction {
 579            requested_by_buffer_id,
 580            prediction,
 581        } = project_state.current_prediction.as_ref()?;
 582
 583        if prediction.targets_buffer(buffer.read(cx)) {
 584            Some(BufferEditPrediction::Local { prediction })
 585        } else if *requested_by_buffer_id == buffer.entity_id() {
 586            Some(BufferEditPrediction::Jump { prediction })
 587        } else {
 588            None
 589        }
 590    }
 591
    /// Reports acceptance of the current prediction to the backend and clears
    /// it from the project's state.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the default accept endpoint.
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // Surfaces errors / update-required handling via shared logic.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 636
 637    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 638        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 639            project_state.current_prediction.take();
 640        };
 641    }
 642
 643    pub fn refresh_prediction(
 644        &mut self,
 645        project: &Entity<Project>,
 646        buffer: &Entity<Buffer>,
 647        position: language::Anchor,
 648        cx: &mut Context<Self>,
 649    ) -> Task<Result<()>> {
 650        let request_task = self.request_prediction(project, buffer, position, cx);
 651        let buffer = buffer.clone();
 652        let project = project.clone();
 653
 654        cx.spawn(async move |this, cx| {
 655            if let Some(prediction) = request_task.await? {
 656                this.update(cx, |this, cx| {
 657                    let project_state = this
 658                        .projects
 659                        .get_mut(&project.entity_id())
 660                        .context("Project not found")?;
 661
 662                    let new_prediction = CurrentEditPrediction {
 663                        requested_by_buffer_id: buffer.entity_id(),
 664                        prediction: prediction,
 665                    };
 666
 667                    if project_state
 668                        .current_prediction
 669                        .as_ref()
 670                        .is_none_or(|old_prediction| {
 671                            new_prediction.should_replace_prediction(&old_prediction, cx)
 672                        })
 673                    {
 674                        project_state.current_prediction = Some(new_prediction);
 675                    }
 676                    anyhow::Ok(())
 677                })??;
 678            }
 679            Ok(())
 680        })
 681    }
 682
 683    pub fn request_prediction(
 684        &mut self,
 685        project: &Entity<Project>,
 686        active_buffer: &Entity<Buffer>,
 687        position: language::Anchor,
 688        cx: &mut Context<Self>,
 689    ) -> Task<Result<Option<EditPrediction>>> {
 690        let project_state = self.projects.get(&project.entity_id());
 691
 692        let index_state = project_state.and_then(|state| {
 693            state
 694                .syntax_index
 695                .as_ref()
 696                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
 697        });
 698        let options = self.options.clone();
 699        let active_snapshot = active_buffer.read(cx).snapshot();
 700        let Some(excerpt_path) = active_snapshot
 701            .file()
 702            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
 703        else {
 704            return Task::ready(Err(anyhow!("No file path for excerpt")));
 705        };
 706        let client = self.client.clone();
 707        let llm_token = self.llm_token.clone();
 708        let app_version = AppVersion::global(cx);
 709        let worktree_snapshots = project
 710            .read(cx)
 711            .worktrees(cx)
 712            .map(|worktree| worktree.read(cx).snapshot())
 713            .collect::<Vec<_>>();
 714        let debug_tx = self.debug_tx.clone();
 715
 716        let events = project_state
 717            .map(|state| {
 718                state
 719                    .events
 720                    .iter()
 721                    .filter_map(|event| event.to_request_event(cx))
 722                    .collect::<Vec<_>>()
 723            })
 724            .unwrap_or_default();
 725
 726        let diagnostics = active_snapshot.diagnostic_sets().clone();
 727
 728        let parent_abs_path =
 729            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
 730                let mut path = f.worktree.read(cx).absolutize(&f.path);
 731                if path.pop() { Some(path) } else { None }
 732            });
 733
 734        // TODO data collection
 735        let can_collect_data = cx.is_staff();
 736
 737        let empty_context_files = HashMap::default();
 738        let context_files = project_state
 739            .and_then(|project_state| project_state.context.as_ref())
 740            .unwrap_or(&empty_context_files);
 741
 742        #[cfg(feature = "eval-support")]
 743        let parsed_fut = futures::future::join_all(
 744            context_files
 745                .keys()
 746                .map(|buffer| buffer.read(cx).parsing_idle()),
 747        );
 748
 749        let mut included_files = context_files
 750            .iter()
 751            .filter_map(|(buffer_entity, ranges)| {
 752                let buffer = buffer_entity.read(cx);
 753                Some((
 754                    buffer_entity.clone(),
 755                    buffer.snapshot(),
 756                    buffer.file()?.full_path(cx).into(),
 757                    ranges.clone(),
 758                ))
 759            })
 760            .collect::<Vec<_>>();
 761
 762        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
 763            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
 764        });
 765
 766        #[cfg(feature = "eval-support")]
 767        let eval_cache = self.eval_cache.clone();
 768
 769        let request_task = cx.background_spawn({
 770            let active_buffer = active_buffer.clone();
 771            async move {
 772                #[cfg(feature = "eval-support")]
 773                parsed_fut.await;
 774
 775                let index_state = if let Some(index_state) = index_state {
 776                    Some(index_state.lock_owned().await)
 777                } else {
 778                    None
 779                };
 780
 781                let cursor_offset = position.to_offset(&active_snapshot);
 782                let cursor_point = cursor_offset.to_point(&active_snapshot);
 783
 784                let before_retrieval = chrono::Utc::now();
 785
 786                let (diagnostic_groups, diagnostic_groups_truncated) =
 787                    Self::gather_nearby_diagnostics(
 788                        cursor_offset,
 789                        &diagnostics,
 790                        &active_snapshot,
 791                        options.max_diagnostic_bytes,
 792                    );
 793
 794                let cloud_request = match options.context {
 795                    ContextMode::Agentic(context_options) => {
 796                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
 797                            cursor_point,
 798                            &active_snapshot,
 799                            &context_options.excerpt,
 800                            index_state.as_deref(),
 801                        ) else {
 802                            return Ok((None, None));
 803                        };
 804
 805                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
 806                            ..active_snapshot.anchor_before(excerpt.range.end);
 807
 808                        if let Some(buffer_ix) =
 809                            included_files.iter().position(|(_, snapshot, _, _)| {
 810                                snapshot.remote_id() == active_snapshot.remote_id()
 811                            })
 812                        {
 813                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
 814                            let range_ix = ranges
 815                                .binary_search_by(|probe| {
 816                                    probe
 817                                        .start
 818                                        .cmp(&excerpt_anchor_range.start, buffer)
 819                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
 820                                })
 821                                .unwrap_or_else(|ix| ix);
 822
 823                            ranges.insert(range_ix, excerpt_anchor_range);
 824                            let last_ix = included_files.len() - 1;
 825                            included_files.swap(buffer_ix, last_ix);
 826                        } else {
 827                            included_files.push((
 828                                active_buffer.clone(),
 829                                active_snapshot,
 830                                excerpt_path.clone(),
 831                                vec![excerpt_anchor_range],
 832                            ));
 833                        }
 834
 835                        let included_files = included_files
 836                            .iter()
 837                            .map(|(_, snapshot, path, ranges)| {
 838                                let excerpts = merge_excerpts(
 839                                    &snapshot,
 840                                    ranges.iter().map(|range| {
 841                                        let point_range = range.to_point(&snapshot);
 842                                        Line(point_range.start.row)..Line(point_range.end.row)
 843                                    }),
 844                                );
 845                                predict_edits_v3::IncludedFile {
 846                                    path: path.clone(),
 847                                    max_row: Line(snapshot.max_point().row),
 848                                    excerpts,
 849                                }
 850                            })
 851                            .collect::<Vec<_>>();
 852
 853                        predict_edits_v3::PredictEditsRequest {
 854                            excerpt_path,
 855                            excerpt: String::new(),
 856                            excerpt_line_range: Line(0)..Line(0),
 857                            excerpt_range: 0..0,
 858                            cursor_point: predict_edits_v3::Point {
 859                                line: predict_edits_v3::Line(cursor_point.row),
 860                                column: cursor_point.column,
 861                            },
 862                            included_files,
 863                            referenced_declarations: vec![],
 864                            events,
 865                            can_collect_data,
 866                            diagnostic_groups,
 867                            diagnostic_groups_truncated,
 868                            debug_info: debug_tx.is_some(),
 869                            prompt_max_bytes: Some(options.max_prompt_bytes),
 870                            prompt_format: options.prompt_format,
 871                            // TODO [zeta2]
 872                            signatures: vec![],
 873                            excerpt_parent: None,
 874                            git_info: None,
 875                        }
 876                    }
 877                    ContextMode::Syntax(context_options) => {
 878                        let Some(context) = EditPredictionContext::gather_context(
 879                            cursor_point,
 880                            &active_snapshot,
 881                            parent_abs_path.as_deref(),
 882                            &context_options,
 883                            index_state.as_deref(),
 884                        ) else {
 885                            return Ok((None, None));
 886                        };
 887
 888                        make_syntax_context_cloud_request(
 889                            excerpt_path,
 890                            context,
 891                            events,
 892                            can_collect_data,
 893                            diagnostic_groups,
 894                            diagnostic_groups_truncated,
 895                            None,
 896                            debug_tx.is_some(),
 897                            &worktree_snapshots,
 898                            index_state.as_deref(),
 899                            Some(options.max_prompt_bytes),
 900                            options.prompt_format,
 901                        )
 902                    }
 903                };
 904
 905                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
 906
 907                let retrieval_time = chrono::Utc::now() - before_retrieval;
 908
 909                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
 910                    let (response_tx, response_rx) = oneshot::channel();
 911
 912                    debug_tx
 913                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
 914                            ZetaEditPredictionDebugInfo {
 915                                request: cloud_request.clone(),
 916                                retrieval_time,
 917                                buffer: active_buffer.downgrade(),
 918                                local_prompt: match prompt_result.as_ref() {
 919                                    Ok((prompt, _)) => Ok(prompt.clone()),
 920                                    Err(err) => Err(err.to_string()),
 921                                },
 922                                position,
 923                                response_rx,
 924                            },
 925                        ))
 926                        .ok();
 927                    Some(response_tx)
 928                } else {
 929                    None
 930                };
 931
 932                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 933                    if let Some(debug_response_tx) = debug_response_tx {
 934                        debug_response_tx
 935                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
 936                            .ok();
 937                    }
 938                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 939                }
 940
 941                let (prompt, _) = prompt_result?;
 942                let request = open_ai::Request {
 943                    model: MODEL_ID.clone(),
 944                    messages: vec![open_ai::RequestMessage::User {
 945                        content: open_ai::MessageContent::Plain(prompt),
 946                    }],
 947                    stream: false,
 948                    max_completion_tokens: None,
 949                    stop: Default::default(),
 950                    temperature: 0.7,
 951                    tool_choice: None,
 952                    parallel_tool_calls: None,
 953                    tools: vec![],
 954                    prompt_cache_key: None,
 955                    reasoning_effort: None,
 956                };
 957
 958                log::trace!("Sending edit prediction request");
 959
 960                let before_request = chrono::Utc::now();
 961                let response = Self::send_raw_llm_request(
 962                    request,
 963                    client,
 964                    llm_token,
 965                    app_version,
 966                    #[cfg(feature = "eval-support")]
 967                    eval_cache,
 968                    #[cfg(feature = "eval-support")]
 969                    EvalCacheEntryKind::Prediction,
 970                )
 971                .await;
 972                let request_time = chrono::Utc::now() - before_request;
 973
 974                log::trace!("Got edit prediction response");
 975
 976                if let Some(debug_response_tx) = debug_response_tx {
 977                    debug_response_tx
 978                        .send((
 979                            response
 980                                .as_ref()
 981                                .map_err(|err| err.to_string())
 982                                .map(|response| response.0.clone()),
 983                            request_time,
 984                        ))
 985                        .ok();
 986                }
 987
 988                let (res, usage) = response?;
 989                let request_id = EditPredictionId(res.id.clone().into());
 990                let Some(mut output_text) = text_from_response(res) else {
 991                    return Ok((None, usage));
 992                };
 993
 994                if output_text.contains(CURSOR_MARKER) {
 995                    log::trace!("Stripping out {CURSOR_MARKER} from response");
 996                    output_text = output_text.replace(CURSOR_MARKER, "");
 997                }
 998
 999                let get_buffer_from_context = |path: &Path| {
1000                    included_files
1001                        .iter()
1002                        .find_map(|(_, buffer, probe_path, ranges)| {
1003                            if probe_path.as_ref() == path {
1004                                Some((buffer, ranges.as_slice()))
1005                            } else {
1006                                None
1007                            }
1008                        })
1009                };
1010
1011                let (edited_buffer_snapshot, edits) = match options.prompt_format {
1012                    PromptFormat::NumLinesUniDiff => {
1013                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1014                    }
1015                    PromptFormat::OldTextNewText => {
1016                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
1017                            .await?
1018                    }
1019                    _ => {
1020                        bail!("unsupported prompt format {}", options.prompt_format)
1021                    }
1022                };
1023
1024                let edited_buffer = included_files
1025                    .iter()
1026                    .find_map(|(buffer, snapshot, _, _)| {
1027                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1028                            Some(buffer.clone())
1029                        } else {
1030                            None
1031                        }
1032                    })
1033                    .context("Failed to find buffer in included_buffers")?;
1034
1035                anyhow::Ok((
1036                    Some((
1037                        request_id,
1038                        edited_buffer,
1039                        edited_buffer_snapshot.clone(),
1040                        edits,
1041                    )),
1042                    usage,
1043                ))
1044            }
1045        });
1046
1047        cx.spawn({
1048            async move |this, cx| {
1049                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
1050                    Self::handle_api_response(&this, request_task.await, cx)?
1051                else {
1052                    return Ok(None);
1053                };
1054
1055                // TODO telemetry: duration, etc
1056                Ok(
1057                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
1058                        .await,
1059                )
1060            }
1061        })
1062    }
1063
1064    async fn send_raw_llm_request(
1065        request: open_ai::Request,
1066        client: Arc<Client>,
1067        llm_token: LlmApiToken,
1068        app_version: SemanticVersion,
1069        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1070        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
1071    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1072        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1073            http_client::Url::parse(&predict_edits_url)?
1074        } else {
1075            client
1076                .http_client()
1077                .build_zed_llm_url("/predict_edits/raw", &[])?
1078        };
1079
1080        #[cfg(feature = "eval-support")]
1081        let cache_key = if let Some(cache) = eval_cache {
1082            use collections::FxHasher;
1083            use std::hash::{Hash, Hasher};
1084
1085            let mut hasher = FxHasher::default();
1086            url.hash(&mut hasher);
1087            let request_str = serde_json::to_string_pretty(&request)?;
1088            request_str.hash(&mut hasher);
1089            let hash = hasher.finish();
1090
1091            let key = (eval_cache_kind, hash);
1092            if let Some(response_str) = cache.read(key) {
1093                return Ok((serde_json::from_str(&response_str)?, None));
1094            }
1095
1096            Some((cache, request_str, key))
1097        } else {
1098            None
1099        };
1100
1101        let (response, usage) = Self::send_api_request(
1102            |builder| {
1103                let req = builder
1104                    .uri(url.as_ref())
1105                    .body(serde_json::to_string(&request)?.into());
1106                Ok(req?)
1107            },
1108            client,
1109            llm_token,
1110            app_version,
1111        )
1112        .await?;
1113
1114        #[cfg(feature = "eval-support")]
1115        if let Some((cache, request, key)) = cache_key {
1116            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1117        }
1118
1119        Ok((response, usage))
1120    }
1121
1122    fn handle_api_response<T>(
1123        this: &WeakEntity<Self>,
1124        response: Result<(T, Option<EditPredictionUsage>)>,
1125        cx: &mut gpui::AsyncApp,
1126    ) -> Result<T> {
1127        match response {
1128            Ok((data, usage)) => {
1129                if let Some(usage) = usage {
1130                    this.update(cx, |this, cx| {
1131                        this.user_store.update(cx, |user_store, cx| {
1132                            user_store.update_edit_prediction_usage(usage, cx);
1133                        });
1134                    })
1135                    .ok();
1136                }
1137                Ok(data)
1138            }
1139            Err(err) => {
1140                if err.is::<ZedUpdateRequiredError>() {
1141                    cx.update(|cx| {
1142                        this.update(cx, |this, _cx| {
1143                            this.update_required = true;
1144                        })
1145                        .ok();
1146
1147                        let error_message: SharedString = err.to_string().into();
1148                        show_app_notification(
1149                            NotificationId::unique::<ZedUpdateRequiredError>(),
1150                            cx,
1151                            move |cx| {
1152                                cx.new(|cx| {
1153                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1154                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1155                                })
1156                            },
1157                        );
1158                    })
1159                    .ok();
1160                }
1161                Err(err)
1162            }
1163        }
1164    }
1165
1166    async fn send_api_request<Res>(
1167        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1168        client: Arc<Client>,
1169        llm_token: LlmApiToken,
1170        app_version: SemanticVersion,
1171    ) -> Result<(Res, Option<EditPredictionUsage>)>
1172    where
1173        Res: DeserializeOwned,
1174    {
1175        let http_client = client.http_client();
1176        let mut token = llm_token.acquire(&client).await?;
1177        let mut did_retry = false;
1178
1179        loop {
1180            let request_builder = http_client::Request::builder().method(Method::POST);
1181
1182            let request = build(
1183                request_builder
1184                    .header("Content-Type", "application/json")
1185                    .header("Authorization", format!("Bearer {}", token))
1186                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
1187            )?;
1188
1189            let mut response = http_client.send(request).await?;
1190
1191            if let Some(minimum_required_version) = response
1192                .headers()
1193                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1194                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
1195            {
1196                anyhow::ensure!(
1197                    app_version >= minimum_required_version,
1198                    ZedUpdateRequiredError {
1199                        minimum_version: minimum_required_version
1200                    }
1201                );
1202            }
1203
1204            if response.status().is_success() {
1205                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1206
1207                let mut body = Vec::new();
1208                response.body_mut().read_to_end(&mut body).await?;
1209                return Ok((serde_json::from_slice(&body)?, usage));
1210            } else if !did_retry
1211                && response
1212                    .headers()
1213                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1214                    .is_some()
1215            {
1216                did_retry = true;
1217                token = llm_token.refresh(&client).await?;
1218            } else {
1219                let mut body = String::new();
1220                response.body_mut().read_to_string(&mut body).await?;
1221                anyhow::bail!(
1222                    "Request failed with status: {:?}\nBody: {}",
1223                    response.status(),
1224                    body
1225                );
1226            }
1227        }
1228    }
1229
    // No refresh request within this window means the user is considered idle;
    // the next edit triggers an immediate context refetch.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    // How long to wait after an edit (when not idle) before refetching context.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1232
1233    // Refresh the related excerpts when the user just beguns editing after
1234    // an idle period, and after they pause editing.
1235    fn refresh_context_if_needed(
1236        &mut self,
1237        project: &Entity<Project>,
1238        buffer: &Entity<language::Buffer>,
1239        cursor_position: language::Anchor,
1240        cx: &mut Context<Self>,
1241    ) {
1242        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
1243            return;
1244        }
1245
1246        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1247            return;
1248        };
1249
1250        let now = Instant::now();
1251        let was_idle = zeta_project
1252            .refresh_context_timestamp
1253            .map_or(true, |timestamp| {
1254                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1255            });
1256        zeta_project.refresh_context_timestamp = Some(now);
1257        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1258            let buffer = buffer.clone();
1259            let project = project.clone();
1260            async move |this, cx| {
1261                if was_idle {
1262                    log::debug!("refetching edit prediction context after idle");
1263                } else {
1264                    cx.background_executor()
1265                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1266                        .await;
1267                    log::debug!("refetching edit prediction context after pause");
1268                }
1269                this.update(cx, |this, cx| {
1270                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
1271
1272                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
1273                        zeta_project.refresh_context_task = Some(task.log_err());
1274                    };
1275                })
1276                .ok()
1277            }
1278        }));
1279    }
1280
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    //
    // Plans a context-retrieval request (the model is asked to call a search
    // tool), runs the resulting searches, and stores the related excerpts on
    // the project state. Emits debug events at each stage when a debug channel
    // is attached.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // No-op for projects we aren't tracking.
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        // Context retrieval only happens in agentic mode.
        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        // Select the excerpt surrounding the cursor to seed the retrieval prompt.
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        // Untitled buffers get a placeholder path.
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the search-planning prompt from the cursor excerpt plus recent
        // edit events recorded on the project.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // The search tool's JSON schema and description are derived once and
        // cached for the lifetime of the process.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        // Non-streaming request exposing exactly one tool: the search tool.
        let request = open_ai::Request {
            model: MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            // Collect the search queries from the assistant's tool calls,
            // skipping calls to any tool other than the search tool.
            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Clear the in-flight task handle and store the new context (or
            // propagate the search error).
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
1496
1497    pub fn set_context(
1498        &mut self,
1499        project: Entity<Project>,
1500        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
1501    ) {
1502        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
1503            zeta_project.context = Some(context);
1504        }
1505    }
1506
1507    fn gather_nearby_diagnostics(
1508        cursor_offset: usize,
1509        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1510        snapshot: &BufferSnapshot,
1511        max_diagnostics_bytes: usize,
1512    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1513        // TODO: Could make this more efficient
1514        let mut diagnostic_groups = Vec::new();
1515        for (language_server_id, diagnostics) in diagnostic_sets {
1516            let mut groups = Vec::new();
1517            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1518            diagnostic_groups.extend(
1519                groups
1520                    .into_iter()
1521                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1522            );
1523        }
1524
1525        // sort by proximity to cursor
1526        diagnostic_groups.sort_by_key(|group| {
1527            let range = &group.entries[group.primary_ix].range;
1528            if range.start >= cursor_offset {
1529                range.start - cursor_offset
1530            } else if cursor_offset >= range.end {
1531                cursor_offset - range.end
1532            } else {
1533                (cursor_offset - range.start).min(range.end - cursor_offset)
1534            }
1535        });
1536
1537        let mut results = Vec::new();
1538        let mut diagnostic_groups_truncated = false;
1539        let mut diagnostics_byte_count = 0;
1540        for group in diagnostic_groups {
1541            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1542            diagnostics_byte_count += raw_value.get().len();
1543            if diagnostics_byte_count > max_diagnostics_bytes {
1544                diagnostic_groups_truncated = true;
1545                break;
1546            }
1547            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1548        }
1549
1550        (results, diagnostic_groups_truncated)
1551    }
1552
1553    // TODO: Dedupe with similar code in request_prediction?
1554    pub fn cloud_request_for_zeta_cli(
1555        &mut self,
1556        project: &Entity<Project>,
1557        buffer: &Entity<Buffer>,
1558        position: language::Anchor,
1559        cx: &mut Context<Self>,
1560    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1561        let project_state = self.projects.get(&project.entity_id());
1562
1563        let index_state = project_state.and_then(|state| {
1564            state
1565                .syntax_index
1566                .as_ref()
1567                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
1568        });
1569        let options = self.options.clone();
1570        let snapshot = buffer.read(cx).snapshot();
1571        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1572            return Task::ready(Err(anyhow!("No file path for excerpt")));
1573        };
1574        let worktree_snapshots = project
1575            .read(cx)
1576            .worktrees(cx)
1577            .map(|worktree| worktree.read(cx).snapshot())
1578            .collect::<Vec<_>>();
1579
1580        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1581            let mut path = f.worktree.read(cx).absolutize(&f.path);
1582            if path.pop() { Some(path) } else { None }
1583        });
1584
1585        cx.background_spawn(async move {
1586            let index_state = if let Some(index_state) = index_state {
1587                Some(index_state.lock_owned().await)
1588            } else {
1589                None
1590            };
1591
1592            let cursor_point = position.to_point(&snapshot);
1593
1594            let debug_info = true;
1595            EditPredictionContext::gather_context(
1596                cursor_point,
1597                &snapshot,
1598                parent_abs_path.as_deref(),
1599                match &options.context {
1600                    ContextMode::Agentic(_) => {
1601                        // TODO
1602                        panic!("Llm mode not supported in zeta cli yet");
1603                    }
1604                    ContextMode::Syntax(edit_prediction_context_options) => {
1605                        edit_prediction_context_options
1606                    }
1607                },
1608                index_state.as_deref(),
1609            )
1610            .context("Failed to select excerpt")
1611            .map(|context| {
1612                make_syntax_context_cloud_request(
1613                    excerpt_path.into(),
1614                    context,
1615                    // TODO pass everything
1616                    Vec::new(),
1617                    false,
1618                    Vec::new(),
1619                    false,
1620                    None,
1621                    debug_info,
1622                    &worktree_snapshots,
1623                    index_state.as_deref(),
1624                    Some(options.max_prompt_bytes),
1625                    options.prompt_format,
1626                )
1627            })
1628        })
1629    }
1630
1631    pub fn wait_for_initial_indexing(
1632        &mut self,
1633        project: &Entity<Project>,
1634        cx: &mut App,
1635    ) -> Task<Result<()>> {
1636        let zeta_project = self.get_or_init_zeta_project(project, cx);
1637        if let Some(syntax_index) = &zeta_project.syntax_index {
1638            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
1639        } else {
1640            Task::ready(Ok(()))
1641        }
1642    }
1643}
1644
1645pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1646    let choice = res.choices.pop()?;
1647    let output_text = match choice.message {
1648        open_ai::RequestMessage::Assistant {
1649            content: Some(open_ai::MessageContent::Plain(content)),
1650            ..
1651        } => content,
1652        open_ai::RequestMessage::Assistant {
1653            content: Some(open_ai::MessageContent::Multipart(mut content)),
1654            ..
1655        } => {
1656            if content.is_empty() {
1657                log::error!("No output from Baseten completion response");
1658                return None;
1659            }
1660
1661            match content.remove(0) {
1662                open_ai::MessagePart::Text { text } => text,
1663                open_ai::MessagePart::Image { .. } => {
1664                    log::error!("Expected text, got an image");
1665                    return None;
1666                }
1667            }
1668        }
1669        _ => {
1670            log::error!("Invalid response message: {:?}", choice.message);
1671            return None;
1672        }
1673    };
1674    Some(output_text)
1675}
1676
/// Error raised when the prediction backend rejects this client as too old.
/// `minimum_version` is presumably parsed from the server's
/// minimum-required-version response header — TODO confirm against the request
/// code earlier in this file.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: SemanticVersion,
}
1684
1685fn make_syntax_context_cloud_request(
1686    excerpt_path: Arc<Path>,
1687    context: EditPredictionContext,
1688    events: Vec<predict_edits_v3::Event>,
1689    can_collect_data: bool,
1690    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1691    diagnostic_groups_truncated: bool,
1692    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1693    debug_info: bool,
1694    worktrees: &Vec<worktree::Snapshot>,
1695    index_state: Option<&SyntaxIndexState>,
1696    prompt_max_bytes: Option<usize>,
1697    prompt_format: PromptFormat,
1698) -> predict_edits_v3::PredictEditsRequest {
1699    let mut signatures = Vec::new();
1700    let mut declaration_to_signature_index = HashMap::default();
1701    let mut referenced_declarations = Vec::new();
1702
1703    for snippet in context.declarations {
1704        let project_entry_id = snippet.declaration.project_entry_id();
1705        let Some(path) = worktrees.iter().find_map(|worktree| {
1706            worktree.entry_for_id(project_entry_id).map(|entry| {
1707                let mut full_path = RelPathBuf::new();
1708                full_path.push(worktree.root_name());
1709                full_path.push(&entry.path);
1710                full_path
1711            })
1712        }) else {
1713            continue;
1714        };
1715
1716        let parent_index = index_state.and_then(|index_state| {
1717            snippet.declaration.parent().and_then(|parent| {
1718                add_signature(
1719                    parent,
1720                    &mut declaration_to_signature_index,
1721                    &mut signatures,
1722                    index_state,
1723                )
1724            })
1725        });
1726
1727        let (text, text_is_truncated) = snippet.declaration.item_text();
1728        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1729            path: path.as_std_path().into(),
1730            text: text.into(),
1731            range: snippet.declaration.item_line_range(),
1732            text_is_truncated,
1733            signature_range: snippet.declaration.signature_range_in_item_text(),
1734            parent_index,
1735            signature_score: snippet.score(DeclarationStyle::Signature),
1736            declaration_score: snippet.score(DeclarationStyle::Declaration),
1737            score_components: snippet.components,
1738        });
1739    }
1740
1741    let excerpt_parent = index_state.and_then(|index_state| {
1742        context
1743            .excerpt
1744            .parent_declarations
1745            .last()
1746            .and_then(|(parent, _)| {
1747                add_signature(
1748                    *parent,
1749                    &mut declaration_to_signature_index,
1750                    &mut signatures,
1751                    index_state,
1752                )
1753            })
1754    });
1755
1756    predict_edits_v3::PredictEditsRequest {
1757        excerpt_path,
1758        excerpt: context.excerpt_text.body,
1759        excerpt_line_range: context.excerpt.line_range,
1760        excerpt_range: context.excerpt.range,
1761        cursor_point: predict_edits_v3::Point {
1762            line: predict_edits_v3::Line(context.cursor_point.row),
1763            column: context.cursor_point.column,
1764        },
1765        referenced_declarations,
1766        included_files: vec![],
1767        signatures,
1768        excerpt_parent,
1769        events,
1770        can_collect_data,
1771        diagnostic_groups,
1772        diagnostic_groups_truncated,
1773        git_info,
1774        debug_info,
1775        prompt_max_bytes,
1776        prompt_format,
1777    }
1778}
1779
1780fn add_signature(
1781    declaration_id: DeclarationId,
1782    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1783    signatures: &mut Vec<Signature>,
1784    index: &SyntaxIndexState,
1785) -> Option<usize> {
1786    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1787        return Some(*signature_index);
1788    }
1789    let Some(parent_declaration) = index.declaration(declaration_id) else {
1790        log::error!("bug: missing parent declaration");
1791        return None;
1792    };
1793    let parent_index = parent_declaration.parent().and_then(|parent| {
1794        add_signature(parent, declaration_to_signature_index, signatures, index)
1795    });
1796    let (text, text_is_truncated) = parent_declaration.signature_text();
1797    let signature_index = signatures.len();
1798    signatures.push(Signature {
1799        text: text.into(),
1800        text_is_truncated,
1801        parent_index,
1802        range: parent_declaration.signature_line_range(),
1803    });
1804    declaration_to_signature_index.insert(declaration_id, signature_index);
1805    Some(signature_index)
1806}
1807
/// Key for an [`EvalCache`] entry: the kind of cached interaction plus a
/// 64-bit discriminator (presumably a hash of the request input — confirm at
/// the call sites).
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
1810
/// The kind of model interaction an eval cache entry corresponds to.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    // Context-gathering request.
    Context,
    // Search/retrieval request.
    Search,
    // Edit-prediction request.
    Prediction,
}
1818
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the lowercase label for this entry kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(label)
    }
}
1829
/// Cache used under the `eval-support` feature to persist model interactions
/// across eval runs, keyed by [`EvalCacheKey`].
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached value for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` under `key`; `input` is the raw request text
    /// (presumably kept for inspection/debugging — confirm in implementors).
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
1835
1836#[cfg(test)]
1837mod tests {
1838    use std::{path::Path, sync::Arc};
1839
1840    use client::UserStore;
1841    use clock::FakeSystemClock;
1842    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
1843    use futures::{
1844        AsyncReadExt, StreamExt,
1845        channel::{mpsc, oneshot},
1846    };
1847    use gpui::{
1848        Entity, TestAppContext,
1849        http_client::{FakeHttpClient, Response},
1850        prelude::*,
1851    };
1852    use indoc::indoc;
1853    use language::OffsetRangeExt as _;
1854    use open_ai::Usage;
1855    use pretty_assertions::{assert_eq, assert_matches};
1856    use project::{FakeFs, Project};
1857    use serde_json::json;
1858    use settings::SettingsStore;
1859    use util::path;
1860    use uuid::Uuid;
1861
1862    use crate::{BufferEditPrediction, Zeta};
1863
    /// Exercises `Zeta`'s current-prediction bookkeeping end to end: a
    /// prediction for the edited buffer surfaces as `Local`, a context refresh
    /// is driven by a search tool call, and a prediction whose diff targets a
    /// different file surfaces as `Jump` (then as `Local` once that file's
    /// buffer is opened).
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();
        prediction_task.await.unwrap();

        // A diff touching the buffer's own file should surface as Local.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Answer the refresh with a single search tool call so it has a
        // query to execute.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        prediction_task.await.unwrap();
        // From buffer1's perspective the prediction targets 2.txt, so it
        // should be a Jump pointing at that file.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once 2.txt's buffer is open, the same prediction is Local for it.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
2008
    /// Single prediction round trip: the model returns a unified diff for the
    /// current file and the parsed prediction's edit lands at the cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The diff turns "How" into "How are you?", so the single parsed edit
        // inserts " are you?" at line 1, column 3.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
2072
    /// Verifies that edits made to a registered buffer are reported to the
    /// model as a unified diff inside the prompt, before the prediction
    /// response is parsed as usual.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Edit the tracked buffer so the change is recorded as an event.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The prompt should contain the edit above as a diff hunk.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
2146
2147    // Skipped until we start including diagnostics in prompt
2148    // #[gpui::test]
2149    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
2150    //     let (zeta, mut req_rx) = init_test(cx);
2151    //     let fs = FakeFs::new(cx.executor());
2152    //     fs.insert_tree(
2153    //         "/root",
2154    //         json!({
2155    //             "foo.md": "Hello!\nBye"
2156    //         }),
2157    //     )
2158    //     .await;
2159    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2160
2161    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
2162    //     let diagnostic = lsp::Diagnostic {
2163    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
2164    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
2165    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
2166    //         ..Default::default()
2167    //     };
2168
2169    //     project.update(cx, |project, cx| {
2170    //         project.lsp_store().update(cx, |lsp_store, cx| {
2171    //             // Create some diagnostics
2172    //             lsp_store
2173    //                 .update_diagnostics(
2174    //                     LanguageServerId(0),
2175    //                     lsp::PublishDiagnosticsParams {
2176    //                         uri: path_to_buffer_uri.clone(),
2177    //                         diagnostics: vec![diagnostic],
2178    //                         version: None,
2179    //                     },
2180    //                     None,
2181    //                     language::DiagnosticSourceKind::Pushed,
2182    //                     &[],
2183    //                     cx,
2184    //                 )
2185    //                 .unwrap();
2186    //         });
2187    //     });
2188
2189    //     let buffer = project
2190    //         .update(cx, |project, cx| {
2191    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2192    //             project.open_buffer(path, cx)
2193    //         })
2194    //         .await
2195    //         .unwrap();
2196
2197    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2198    //     let position = snapshot.anchor_before(language::Point::new(0, 0));
2199
2200    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
2201    //         zeta.request_prediction(&project, &buffer, position, cx)
2202    //     });
2203
2204    //     let (request, _respond_tx) = req_rx.next().await.unwrap();
2205
2206    //     assert_eq!(request.diagnostic_groups.len(), 1);
2207    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
2208    //         .unwrap();
2209    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
2210    //     assert_eq!(
2211    //         value,
2212    //         json!({
2213    //             "entries": [{
2214    //                 "range": {
2215    //                     "start": 8,
2216    //                     "end": 10
2217    //                 },
2218    //                 "diagnostic": {
2219    //                     "source": null,
2220    //                     "code": null,
2221    //                     "code_description": null,
2222    //                     "severity": 1,
2223    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
2224    //                     "markdown": null,
2225    //                     "group_id": 0,
2226    //                     "is_primary": true,
2227    //                     "is_disk_based": false,
2228    //                     "is_unnecessary": false,
2229    //                     "source_kind": "Pushed",
2230    //                     "data": null,
2231    //                     "underline": true
2232    //                 }
2233    //             }],
2234    //             "primary_ix": 0
2235    //         })
2236    //     );
2237    // }
2238
    /// Builds a minimal `open_ai::Response` whose single choice is a
    /// plain-text assistant message containing `text`.
    fn model_response(text: &str) -> open_ai::Response {
        open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
                    tool_calls: vec![],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }
2260
2261    fn prompt_from_request(request: &open_ai::Request) -> &str {
2262        assert_eq!(request.messages.len(), 1);
2263        let open_ai::RequestMessage::User {
2264            content: open_ai::MessageContent::Plain(content),
2265            ..
2266        } = &request.messages[0]
2267        else {
2268            panic!(
2269                "Request does not have single user message of type Plain. {:#?}",
2270                request
2271            );
2272        };
2273        content
2274    }
2275
    /// Creates a `Zeta` entity backed by a fake HTTP client.
    ///
    /// Returns the entity plus a channel yielding each `/predict_edits/raw`
    /// request together with a oneshot sender the test uses to supply the
    /// model response.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh: hand back a static fake token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction request: forward to the test through
                            // the channel and wait for its response.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
2330}