zeta2.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  10use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  11use collections::HashMap;
  12use edit_prediction_context::{
  13    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  14    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  15    SyntaxIndex, SyntaxIndexState,
  16};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::AsyncReadExt as _;
  19use futures::channel::{mpsc, oneshot};
  20use gpui::http_client::{AsyncBody, Method};
  21use gpui::{
  22    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  23    http_client, prelude::*,
  24};
  25use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  26use language::{BufferSnapshot, OffsetRangeExt};
  27use language_model::{LlmApiToken, RefreshLlmTokenListener};
  28use open_ai::FunctionDefinition;
  29use project::Project;
  30use release_channel::AppVersion;
  31use serde::de::DeserializeOwned;
  32use std::collections::{VecDeque, hash_map};
  33
  34use std::env;
  35use std::ops::Range;
  36use std::path::Path;
  37use std::str::FromStr as _;
  38use std::sync::{Arc, LazyLock};
  39use std::time::{Duration, Instant};
  40use thiserror::Error;
  41use util::rel_path::RelPathBuf;
  42use util::{LogErrorFuture, TryFutureExt};
  43use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  44
  45pub mod merge_excerpts;
  46mod prediction;
  47mod provider;
  48pub mod retrieval_search;
  49pub mod udiff;
  50
  51use crate::merge_excerpts::merge_excerpts;
  52use crate::prediction::EditPrediction;
  53pub use crate::prediction::EditPredictionId;
  54pub use provider::ZetaEditPredictionProvider;
  55
  56/// Maximum number of events to track.
  57const MAX_EVENT_COUNT: usize = 16;
  58
  59pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
  60    max_bytes: 512,
  61    min_bytes: 128,
  62    target_before_cursor_over_total_bytes: 0.5,
  63};
  64
  65pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
  66    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);
  67
  68pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
  69    excerpt: DEFAULT_EXCERPT_OPTIONS,
  70};
  71
  72pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
  73    EditPredictionContextOptions {
  74        use_imports: true,
  75        max_retrieved_declarations: 0,
  76        excerpt: DEFAULT_EXCERPT_OPTIONS,
  77        score: EditPredictionScoreOptions {
  78            omit_excerpt_overlaps: true,
  79        },
  80    };
  81
  82pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
  83    context: DEFAULT_CONTEXT_OPTIONS,
  84    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
  85    max_diagnostic_bytes: 2048,
  86    prompt_format: PromptFormat::DEFAULT,
  87    file_indexing_parallelism: 1,
  88    buffer_change_grouping_interval: Duration::from_secs(1),
  89};
  90
  91static USE_OLLAMA: LazyLock<bool> =
  92    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
  93static MODEL_ID: LazyLock<String> = LazyLock::new(|| {
  94    env::var("ZED_ZETA2_MODEL").unwrap_or(if *USE_OLLAMA {
  95        "qwen3-coder:30b".to_string()
  96    } else {
  97        "yqvev8r3".to_string()
  98    })
  99});
 100static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 101    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 102        if *USE_OLLAMA {
 103            Some("http://localhost:11434/v1/chat/completions".into())
 104        } else {
 105            None
 106        }
 107    })
 108});
 109
/// Feature flag gating the zeta2 edit-prediction model.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Returning `false` opts this flag out of being enabled for staff.
    // NOTE(review): the default behavior of `enabled_for_staff` lives in `feature_flags`; confirm there.
    fn enabled_for_staff() -> bool {
        false
    }
}

/// App-global holder of the shared `Zeta` entity, installed by `Zeta::global` and read by
/// `Zeta::try_global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 124
/// Edit-prediction engine state, shared across all registered projects.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token for LLM requests; refreshed whenever the global `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive for the lifetime of `Zeta`.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options applied to subsequent requests; see `set_options`.
    options: ZetaOptions,
    // Starts out `false`; presumably set when the server demands a newer Zed version
    // (the writer is outside this chunk — confirm).
    update_required: bool,
    /// Sender for debug events; installed by `debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "llm-response-cache")]
    llm_response_cache: Option<Arc<dyn LlmResponseCache>>,
}

/// Pluggable cache for raw LLM responses; only compiled with the `llm-response-cache` feature.
#[cfg(feature = "llm-response-cache")]
pub trait LlmResponseCache: Send + Sync {
    /// Computes the cache key for a request from its URL and serialized body.
    fn get_key(&self, url: &gpui::http_client::Url, body: &str) -> u64;
    /// Returns the cached response stored under `key`, if any.
    fn read_response(&self, key: u64) -> Option<String>;
    /// Stores `value` under `key`.
    fn write_response(&self, key: u64, value: &str);
}
 144
/// Tunable settings for edit-prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context is gathered for the request.
    pub context: ContextMode,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the nearby diagnostics included in a request.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Passed to `SyntaxIndex::new` when a project is first registered.
    pub file_indexing_parallelism: usize,
    /// Window within which consecutive edits of the same buffer are coalesced into one event;
    /// `Duration::ZERO` disables coalescing.
    pub buffer_change_grouping_interval: Duration,
}

/// Strategy for selecting the context sent alongside the cursor excerpt.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Uses the per-buffer ranges stored in the project's retrieved context, plus the cursor
    /// excerpt, as the request's included files.
    Agentic(AgenticContextOptions),
    /// Syntax-index based context (its request path is outside this chunk).
    Syntax(EditPredictionContextOptions),
}

/// Settings for `ContextMode::Agentic`.
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    /// Sizing of the excerpt selected around the cursor.
    pub excerpt: EditPredictionExcerptOptions,
}
 165
 166impl ContextMode {
 167    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 168        match self {
 169            ContextMode::Agentic(options) => &options.excerpt,
 170            ContextMode::Syntax(options) => &options.excerpt,
 171        }
 172    }
 173}
 174
/// Debug events delivered through the receiver returned by `Zeta::debug_info`.
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

/// Payload of `ZetaDebugInfo::ContextRetrievalStarted`.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    // Presumably the prompt used to generate search queries — confirm at the emit site.
    pub search_prompt: String,
}

/// Payload of the `SearchQueriesExecuted` and `ContextRetrievalFinished` debug events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Payload of `ZetaDebugInfo::EditPredictionRequested`.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    // Presumably the time spent retrieving context before the request — confirm at the emit site.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    // NOTE(review): looks like the locally rendered prompt, or an error message — confirm.
    pub local_prompt: Result<String, String>,
    // Resolves with the model response (or an error message) plus a duration, presumably the
    // request latency.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

/// Payload of `ZetaDebugInfo::SearchQueriesGenerated`.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

/// Debug information attached to prediction requests/responses.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 215
/// Prediction state kept for each registered project.
struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; kept at most `MAX_EVENT_COUNT` long.
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Retrieved context ranges per buffer; in agentic mode these become the request's
    /// included files.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}

/// A prediction together with the buffer whose request produced it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Buffer the prediction was requested from; may differ from the buffer it edits.
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
 232
 233impl CurrentEditPrediction {
 234    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 235        let Some(new_edits) = self
 236            .prediction
 237            .interpolate(&self.prediction.buffer.read(cx))
 238        else {
 239            return false;
 240        };
 241
 242        if self.prediction.buffer != old_prediction.prediction.buffer {
 243            return true;
 244        }
 245
 246        let Some(old_edits) = old_prediction
 247            .prediction
 248            .interpolate(&old_prediction.prediction.buffer.read(cx))
 249        else {
 250            return true;
 251        };
 252
 253        // This reduces the occurrence of UI thrash from replacing edits
 254        //
 255        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 256        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 257            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 258            && old_edits.len() == 1
 259            && new_edits.len() == 1
 260        {
 261            let (old_range, old_text) = &old_edits[0];
 262            let (new_range, new_text) = &new_edits[0];
 263            new_range == old_range && new_text.starts_with(old_text.as_ref())
 264        } else {
 265            true
 266        }
 267    }
 268}
 269
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction was requested from this buffer but edits another one.
    Jump { prediction: &'a EditPrediction },
}

/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; the next change is diffed against it.
    snapshot: BufferSnapshot,
    /// Edit subscription and release observer, held for as long as the buffer is registered.
    _subscriptions: [gpui::Subscription; 2],
}

/// A change recorded in a tracked buffer.
#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the change (or the latest change coalesced into it) was recorded.
        timestamp: Instant,
    },
}
 290
 291impl Event {
 292    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
 293        match self {
 294            Event::BufferChange {
 295                old_snapshot,
 296                new_snapshot,
 297                ..
 298            } => {
 299                let path = new_snapshot.file().map(|f| f.full_path(cx));
 300
 301                let old_path = old_snapshot.file().and_then(|f| {
 302                    let old_path = f.full_path(cx);
 303                    if Some(&old_path) != path.as_ref() {
 304                        Some(old_path)
 305                    } else {
 306                        None
 307                    }
 308                });
 309
 310                // TODO [zeta2] move to bg?
 311                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 312
 313                if path == old_path && diff.is_empty() {
 314                    None
 315                } else {
 316                    Some(predict_edits_v3::Event::BufferChange {
 317                        old_path,
 318                        path,
 319                        diff,
 320                        //todo: Actually detect if this edit was predicted or not
 321                        predicted: false,
 322                    })
 323                }
 324            }
 325        }
 326    }
 327}
 328
 329impl Zeta {
 330    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 331        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 332    }
 333
 334    pub fn global(
 335        client: &Arc<Client>,
 336        user_store: &Entity<UserStore>,
 337        cx: &mut App,
 338    ) -> Entity<Self> {
 339        cx.try_global::<ZetaGlobal>()
 340            .map(|global| global.0.clone())
 341            .unwrap_or_else(|| {
 342                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 343                cx.set_global(ZetaGlobal(zeta.clone()));
 344                zeta
 345            })
 346    }
 347
 348    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 349        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 350
 351        Self {
 352            projects: HashMap::default(),
 353            client,
 354            user_store,
 355            options: DEFAULT_OPTIONS,
 356            llm_token: LlmApiToken::default(),
 357            _llm_token_subscription: cx.subscribe(
 358                &refresh_llm_token_listener,
 359                |this, _listener, _event, cx| {
 360                    let client = this.client.clone();
 361                    let llm_token = this.llm_token.clone();
 362                    cx.spawn(async move |_this, _cx| {
 363                        llm_token.refresh(&client).await?;
 364                        anyhow::Ok(())
 365                    })
 366                    .detach_and_log_err(cx);
 367                },
 368            ),
 369            update_required: false,
 370            debug_tx: None,
 371            #[cfg(feature = "llm-response-cache")]
 372            llm_response_cache: None,
 373        }
 374    }
 375
 376    #[cfg(feature = "llm-response-cache")]
 377    pub fn with_llm_response_cache(&mut self, cache: Arc<dyn LlmResponseCache>) {
 378        self.llm_response_cache = Some(cache);
 379    }
 380
 381    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 382        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 383        self.debug_tx = Some(debug_watch_tx);
 384        debug_watch_rx
 385    }
 386
 387    pub fn options(&self) -> &ZetaOptions {
 388        &self.options
 389    }
 390
 391    pub fn set_options(&mut self, options: ZetaOptions) {
 392        self.options = options;
 393    }
 394
 395    pub fn clear_history(&mut self) {
 396        for zeta_project in self.projects.values_mut() {
 397            zeta_project.events.clear();
 398        }
 399    }
 400
 401    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
 402        self.projects
 403            .get(&project.entity_id())
 404            .map(|project| project.events.iter())
 405            .into_iter()
 406            .flatten()
 407    }
 408
 409    pub fn context_for_project(
 410        &self,
 411        project: &Entity<Project>,
 412    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 413        self.projects
 414            .get(&project.entity_id())
 415            .and_then(|project| {
 416                Some(
 417                    project
 418                        .context
 419                        .as_ref()?
 420                        .iter()
 421                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 422                )
 423            })
 424            .into_iter()
 425            .flatten()
 426    }
 427
 428    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 429        self.user_store.read(cx).edit_prediction_usage()
 430    }
 431
 432    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 433        self.get_or_init_zeta_project(project, cx);
 434    }
 435
 436    pub fn register_buffer(
 437        &mut self,
 438        buffer: &Entity<Buffer>,
 439        project: &Entity<Project>,
 440        cx: &mut Context<Self>,
 441    ) {
 442        let zeta_project = self.get_or_init_zeta_project(project, cx);
 443        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 444    }
 445
 446    fn get_or_init_zeta_project(
 447        &mut self,
 448        project: &Entity<Project>,
 449        cx: &mut App,
 450    ) -> &mut ZetaProject {
 451        self.projects
 452            .entry(project.entity_id())
 453            .or_insert_with(|| ZetaProject {
 454                syntax_index: cx.new(|cx| {
 455                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 456                }),
 457                events: VecDeque::new(),
 458                registered_buffers: HashMap::default(),
 459                current_prediction: None,
 460                context: None,
 461                refresh_context_task: None,
 462                refresh_context_debounce_task: None,
 463                refresh_context_timestamp: None,
 464            })
 465    }
 466
 467    fn register_buffer_impl<'a>(
 468        zeta_project: &'a mut ZetaProject,
 469        buffer: &Entity<Buffer>,
 470        project: &Entity<Project>,
 471        cx: &mut Context<Self>,
 472    ) -> &'a mut RegisteredBuffer {
 473        let buffer_id = buffer.entity_id();
 474        match zeta_project.registered_buffers.entry(buffer_id) {
 475            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 476            hash_map::Entry::Vacant(entry) => {
 477                let snapshot = buffer.read(cx).snapshot();
 478                let project_entity_id = project.entity_id();
 479                entry.insert(RegisteredBuffer {
 480                    snapshot,
 481                    _subscriptions: [
 482                        cx.subscribe(buffer, {
 483                            let project = project.downgrade();
 484                            move |this, buffer, event, cx| {
 485                                if let language::BufferEvent::Edited = event
 486                                    && let Some(project) = project.upgrade()
 487                                {
 488                                    this.report_changes_for_buffer(&buffer, &project, cx);
 489                                }
 490                            }
 491                        }),
 492                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 493                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 494                            else {
 495                                return;
 496                            };
 497                            zeta_project.registered_buffers.remove(&buffer_id);
 498                        }),
 499                    ],
 500                })
 501            }
 502        }
 503    }
 504
 505    fn report_changes_for_buffer(
 506        &mut self,
 507        buffer: &Entity<Buffer>,
 508        project: &Entity<Project>,
 509        cx: &mut Context<Self>,
 510    ) -> BufferSnapshot {
 511        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
 512        let zeta_project = self.get_or_init_zeta_project(project, cx);
 513        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 514
 515        let new_snapshot = buffer.read(cx).snapshot();
 516        if new_snapshot.version != registered_buffer.snapshot.version {
 517            let old_snapshot =
 518                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 519            Self::push_event(
 520                zeta_project,
 521                buffer_change_grouping_interval,
 522                Event::BufferChange {
 523                    old_snapshot,
 524                    new_snapshot: new_snapshot.clone(),
 525                    timestamp: Instant::now(),
 526                },
 527            );
 528        }
 529
 530        new_snapshot
 531    }
 532
 533    fn push_event(
 534        zeta_project: &mut ZetaProject,
 535        buffer_change_grouping_interval: Duration,
 536        event: Event,
 537    ) {
 538        let events = &mut zeta_project.events;
 539
 540        if buffer_change_grouping_interval > Duration::ZERO
 541            && let Some(Event::BufferChange {
 542                new_snapshot: last_new_snapshot,
 543                timestamp: last_timestamp,
 544                ..
 545            }) = events.back_mut()
 546        {
 547            // Coalesce edits for the same buffer when they happen one after the other.
 548            let Event::BufferChange {
 549                old_snapshot,
 550                new_snapshot,
 551                timestamp,
 552            } = &event;
 553
 554            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
 555                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 556                && old_snapshot.version == last_new_snapshot.version
 557            {
 558                *last_new_snapshot = new_snapshot.clone();
 559                *last_timestamp = *timestamp;
 560                return;
 561            }
 562        }
 563
 564        if events.len() >= MAX_EVENT_COUNT {
 565            // These are halved instead of popping to improve prompt caching.
 566            events.drain(..MAX_EVENT_COUNT / 2);
 567        }
 568
 569        events.push_back(event);
 570    }
 571
 572    fn current_prediction_for_buffer(
 573        &self,
 574        buffer: &Entity<Buffer>,
 575        project: &Entity<Project>,
 576        cx: &App,
 577    ) -> Option<BufferEditPrediction<'_>> {
 578        let project_state = self.projects.get(&project.entity_id())?;
 579
 580        let CurrentEditPrediction {
 581            requested_by_buffer_id,
 582            prediction,
 583        } = project_state.current_prediction.as_ref()?;
 584
 585        if prediction.targets_buffer(buffer.read(cx)) {
 586            Some(BufferEditPrediction::Local { prediction })
 587        } else if *requested_by_buffer_id == buffer.entity_id() {
 588            Some(BufferEditPrediction::Jump { prediction })
 589        } else {
 590            None
 591        }
 592    }
 593
 594    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 595        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 596            return;
 597        };
 598
 599        let Some(prediction) = project_state.current_prediction.take() else {
 600            return;
 601        };
 602        let request_id = prediction.prediction.id.to_string();
 603
 604        let client = self.client.clone();
 605        let llm_token = self.llm_token.clone();
 606        let app_version = AppVersion::global(cx);
 607        cx.spawn(async move |this, cx| {
 608            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
 609                http_client::Url::parse(&predict_edits_url)?
 610            } else {
 611                client
 612                    .http_client()
 613                    .build_zed_llm_url("/predict_edits/accept", &[])?
 614            };
 615
 616            let response = cx
 617                .background_spawn(Self::send_api_request::<()>(
 618                    move |builder| {
 619                        let req = builder.uri(url.as_ref()).body(
 620                            serde_json::to_string(&AcceptEditPredictionBody {
 621                                request_id: request_id.clone(),
 622                            })?
 623                            .into(),
 624                        );
 625                        Ok(req?)
 626                    },
 627                    client,
 628                    llm_token,
 629                    app_version,
 630                ))
 631                .await;
 632
 633            Self::handle_api_response(&this, response, cx)?;
 634            anyhow::Ok(())
 635        })
 636        .detach_and_log_err(cx);
 637    }
 638
 639    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 640        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 641            project_state.current_prediction.take();
 642        };
 643    }
 644
 645    pub fn refresh_prediction(
 646        &mut self,
 647        project: &Entity<Project>,
 648        buffer: &Entity<Buffer>,
 649        position: language::Anchor,
 650        cx: &mut Context<Self>,
 651    ) -> Task<Result<()>> {
 652        let request_task = self.request_prediction(project, buffer, position, cx);
 653        let buffer = buffer.clone();
 654        let project = project.clone();
 655
 656        cx.spawn(async move |this, cx| {
 657            if let Some(prediction) = request_task.await? {
 658                this.update(cx, |this, cx| {
 659                    let project_state = this
 660                        .projects
 661                        .get_mut(&project.entity_id())
 662                        .context("Project not found")?;
 663
 664                    let new_prediction = CurrentEditPrediction {
 665                        requested_by_buffer_id: buffer.entity_id(),
 666                        prediction: prediction,
 667                    };
 668
 669                    if project_state
 670                        .current_prediction
 671                        .as_ref()
 672                        .is_none_or(|old_prediction| {
 673                            new_prediction.should_replace_prediction(&old_prediction, cx)
 674                        })
 675                    {
 676                        project_state.current_prediction = Some(new_prediction);
 677                    }
 678                    anyhow::Ok(())
 679                })??;
 680            }
 681            Ok(())
 682        })
 683    }
 684
 685    pub fn request_prediction(
 686        &mut self,
 687        project: &Entity<Project>,
 688        active_buffer: &Entity<Buffer>,
 689        position: language::Anchor,
 690        cx: &mut Context<Self>,
 691    ) -> Task<Result<Option<EditPrediction>>> {
 692        let project_state = self.projects.get(&project.entity_id());
 693
 694        let index_state = project_state.map(|state| {
 695            state
 696                .syntax_index
 697                .read_with(cx, |index, _cx| index.state().clone())
 698        });
 699        let options = self.options.clone();
 700        let active_snapshot = active_buffer.read(cx).snapshot();
 701        let Some(excerpt_path) = active_snapshot
 702            .file()
 703            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
 704        else {
 705            return Task::ready(Err(anyhow!("No file path for excerpt")));
 706        };
 707        let client = self.client.clone();
 708        let llm_token = self.llm_token.clone();
 709        let app_version = AppVersion::global(cx);
 710        let worktree_snapshots = project
 711            .read(cx)
 712            .worktrees(cx)
 713            .map(|worktree| worktree.read(cx).snapshot())
 714            .collect::<Vec<_>>();
 715        let debug_tx = self.debug_tx.clone();
 716
 717        let events = project_state
 718            .map(|state| {
 719                state
 720                    .events
 721                    .iter()
 722                    .filter_map(|event| event.to_request_event(cx))
 723                    .collect::<Vec<_>>()
 724            })
 725            .unwrap_or_default();
 726
 727        let diagnostics = active_snapshot.diagnostic_sets().clone();
 728
 729        let parent_abs_path =
 730            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
 731                let mut path = f.worktree.read(cx).absolutize(&f.path);
 732                if path.pop() { Some(path) } else { None }
 733            });
 734
 735        // TODO data collection
 736        let can_collect_data = cx.is_staff();
 737
 738        let mut included_files = project_state
 739            .and_then(|project_state| project_state.context.as_ref())
 740            .unwrap_or(&HashMap::default())
 741            .iter()
 742            .filter_map(|(buffer_entity, ranges)| {
 743                let buffer = buffer_entity.read(cx);
 744                Some((
 745                    buffer_entity.clone(),
 746                    buffer.snapshot(),
 747                    buffer.file()?.full_path(cx).into(),
 748                    ranges.clone(),
 749                ))
 750            })
 751            .collect::<Vec<_>>();
 752
 753        #[cfg(feature = "llm-response-cache")]
 754        let llm_response_cache = self.llm_response_cache.clone();
 755
 756        let request_task = cx.background_spawn({
 757            let active_buffer = active_buffer.clone();
 758            async move {
 759                let index_state = if let Some(index_state) = index_state {
 760                    Some(index_state.lock_owned().await)
 761                } else {
 762                    None
 763                };
 764
 765                let cursor_offset = position.to_offset(&active_snapshot);
 766                let cursor_point = cursor_offset.to_point(&active_snapshot);
 767
 768                let before_retrieval = chrono::Utc::now();
 769
 770                let (diagnostic_groups, diagnostic_groups_truncated) =
 771                    Self::gather_nearby_diagnostics(
 772                        cursor_offset,
 773                        &diagnostics,
 774                        &active_snapshot,
 775                        options.max_diagnostic_bytes,
 776                    );
 777
 778                let cloud_request = match options.context {
 779                    ContextMode::Agentic(context_options) => {
 780                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
 781                            cursor_point,
 782                            &active_snapshot,
 783                            &context_options.excerpt,
 784                            index_state.as_deref(),
 785                        ) else {
 786                            return Ok((None, None));
 787                        };
 788
 789                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
 790                            ..active_snapshot.anchor_before(excerpt.range.end);
 791
 792                        if let Some(buffer_ix) =
 793                            included_files.iter().position(|(_, snapshot, _, _)| {
 794                                snapshot.remote_id() == active_snapshot.remote_id()
 795                            })
 796                        {
 797                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
 798                            let range_ix = ranges
 799                                .binary_search_by(|probe| {
 800                                    probe
 801                                        .start
 802                                        .cmp(&excerpt_anchor_range.start, buffer)
 803                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
 804                                })
 805                                .unwrap_or_else(|ix| ix);
 806
 807                            ranges.insert(range_ix, excerpt_anchor_range);
 808                            let last_ix = included_files.len() - 1;
 809                            included_files.swap(buffer_ix, last_ix);
 810                        } else {
 811                            included_files.push((
 812                                active_buffer.clone(),
 813                                active_snapshot,
 814                                excerpt_path.clone(),
 815                                vec![excerpt_anchor_range],
 816                            ));
 817                        }
 818
 819                        let included_files = included_files
 820                            .iter()
 821                            .map(|(_, buffer, path, ranges)| {
 822                                let excerpts = merge_excerpts(
 823                                    &buffer,
 824                                    ranges.iter().map(|range| {
 825                                        let point_range = range.to_point(&buffer);
 826                                        Line(point_range.start.row)..Line(point_range.end.row)
 827                                    }),
 828                                );
 829                                predict_edits_v3::IncludedFile {
 830                                    path: path.clone(),
 831                                    max_row: Line(buffer.max_point().row),
 832                                    excerpts,
 833                                }
 834                            })
 835                            .collect::<Vec<_>>();
 836
 837                        predict_edits_v3::PredictEditsRequest {
 838                            excerpt_path,
 839                            excerpt: String::new(),
 840                            excerpt_line_range: Line(0)..Line(0),
 841                            excerpt_range: 0..0,
 842                            cursor_point: predict_edits_v3::Point {
 843                                line: predict_edits_v3::Line(cursor_point.row),
 844                                column: cursor_point.column,
 845                            },
 846                            included_files,
 847                            referenced_declarations: vec![],
 848                            events,
 849                            can_collect_data,
 850                            diagnostic_groups,
 851                            diagnostic_groups_truncated,
 852                            debug_info: debug_tx.is_some(),
 853                            prompt_max_bytes: Some(options.max_prompt_bytes),
 854                            prompt_format: options.prompt_format,
 855                            // TODO [zeta2]
 856                            signatures: vec![],
 857                            excerpt_parent: None,
 858                            git_info: None,
 859                        }
 860                    }
 861                    ContextMode::Syntax(context_options) => {
 862                        let Some(context) = EditPredictionContext::gather_context(
 863                            cursor_point,
 864                            &active_snapshot,
 865                            parent_abs_path.as_deref(),
 866                            &context_options,
 867                            index_state.as_deref(),
 868                        ) else {
 869                            return Ok((None, None));
 870                        };
 871
 872                        make_syntax_context_cloud_request(
 873                            excerpt_path,
 874                            context,
 875                            events,
 876                            can_collect_data,
 877                            diagnostic_groups,
 878                            diagnostic_groups_truncated,
 879                            None,
 880                            debug_tx.is_some(),
 881                            &worktree_snapshots,
 882                            index_state.as_deref(),
 883                            Some(options.max_prompt_bytes),
 884                            options.prompt_format,
 885                        )
 886                    }
 887                };
 888
 889                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
 890
 891                let retrieval_time = chrono::Utc::now() - before_retrieval;
 892
 893                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
 894                    let (response_tx, response_rx) = oneshot::channel();
 895
 896                    debug_tx
 897                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
 898                            ZetaEditPredictionDebugInfo {
 899                                request: cloud_request.clone(),
 900                                retrieval_time,
 901                                buffer: active_buffer.downgrade(),
 902                                local_prompt: match prompt_result.as_ref() {
 903                                    Ok((prompt, _)) => Ok(prompt.clone()),
 904                                    Err(err) => Err(err.to_string()),
 905                                },
 906                                position,
 907                                response_rx,
 908                            },
 909                        ))
 910                        .ok();
 911                    Some(response_tx)
 912                } else {
 913                    None
 914                };
 915
 916                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 917                    if let Some(debug_response_tx) = debug_response_tx {
 918                        debug_response_tx
 919                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
 920                            .ok();
 921                    }
 922                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 923                }
 924
 925                let (prompt, _) = prompt_result?;
 926                let request = open_ai::Request {
 927                    model: MODEL_ID.clone(),
 928                    messages: vec![open_ai::RequestMessage::User {
 929                        content: open_ai::MessageContent::Plain(prompt),
 930                    }],
 931                    stream: false,
 932                    max_completion_tokens: None,
 933                    stop: Default::default(),
 934                    temperature: 0.7,
 935                    tool_choice: None,
 936                    parallel_tool_calls: None,
 937                    tools: vec![],
 938                    prompt_cache_key: None,
 939                    reasoning_effort: None,
 940                };
 941
 942                log::trace!("Sending edit prediction request");
 943
 944                let before_request = chrono::Utc::now();
 945                let response = Self::send_raw_llm_request(
 946                    request,
 947                    client,
 948                    llm_token,
 949                    app_version,
 950                    #[cfg(feature = "llm-response-cache")]
 951                    llm_response_cache
 952                ).await;
 953                let request_time = chrono::Utc::now() - before_request;
 954
 955                log::trace!("Got edit prediction response");
 956
 957                if let Some(debug_response_tx) = debug_response_tx {
 958                    debug_response_tx
 959                        .send((
 960                            response
 961                                .as_ref()
 962                                .map_err(|err| err.to_string())
 963                                .map(|response| response.0.clone()),
 964                            request_time,
 965                        ))
 966                        .ok();
 967                }
 968
 969                let (res, usage) = response?;
 970                let request_id = EditPredictionId(res.id.clone().into());
 971                let Some(mut output_text) = text_from_response(res) else {
 972                    return Ok((None, usage))
 973                };
 974
 975                if output_text.contains(CURSOR_MARKER) {
 976                    log::trace!("Stripping out {CURSOR_MARKER} from response");
 977                    output_text = output_text.replace(CURSOR_MARKER, "");
 978                }
 979
 980                let (edited_buffer_snapshot, edits) = match options.prompt_format {
 981                    PromptFormat::NumLinesUniDiff => {
 982                        crate::udiff::parse_diff(&output_text, |path| {
 983                            included_files
 984                                .iter()
 985                                .find_map(|(_, buffer, probe_path, ranges)| {
 986                                    if probe_path.as_ref() == path {
 987                                        Some((buffer, ranges.as_slice()))
 988                                    } else {
 989                                        None
 990                                    }
 991                                })
 992                        })
 993                        .await?
 994                    }
 995                    _ => {
 996                        bail!("unsupported prompt format {}", options.prompt_format)
 997                    }
 998                };
 999
1000                let edited_buffer = included_files
1001                    .iter()
1002                    .find_map(|(buffer, snapshot, _, _)| {
1003                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1004                            Some(buffer.clone())
1005                        } else {
1006                            None
1007                        }
1008                    })
1009                    .context("Failed to find buffer in included_buffers, even though we just found the snapshot")?;
1010
1011                anyhow::Ok((Some((request_id, edited_buffer, edited_buffer_snapshot.clone(), edits)), usage))
1012            }
1013        });
1014
1015        cx.spawn({
1016            async move |this, cx| {
1017                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
1018                    Self::handle_api_response(&this, request_task.await, cx)?
1019                else {
1020                    return Ok(None);
1021                };
1022
1023                // TODO telemetry: duration, etc
1024                Ok(
1025                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
1026                        .await,
1027                )
1028            }
1029        })
1030    }
1031
1032    async fn send_raw_llm_request(
1033        request: open_ai::Request,
1034        client: Arc<Client>,
1035        llm_token: LlmApiToken,
1036        app_version: SemanticVersion,
1037        #[cfg(feature = "llm-response-cache")] llm_response_cache: Option<
1038            Arc<dyn LlmResponseCache>,
1039        >,
1040    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1041        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1042            http_client::Url::parse(&predict_edits_url)?
1043        } else {
1044            client
1045                .http_client()
1046                .build_zed_llm_url("/predict_edits/raw", &[])?
1047        };
1048
1049        #[cfg(feature = "llm-response-cache")]
1050        let cache_key = if let Some(cache) = llm_response_cache {
1051            let request_json = serde_json::to_string(&request)?;
1052            let key = cache.get_key(&url, &request_json);
1053
1054            if let Some(response_str) = cache.read_response(key) {
1055                return Ok((serde_json::from_str(&response_str)?, None));
1056            }
1057
1058            Some((cache, key))
1059        } else {
1060            None
1061        };
1062
1063        let (response, usage) = Self::send_api_request(
1064            |builder| {
1065                let req = builder
1066                    .uri(url.as_ref())
1067                    .body(serde_json::to_string(&request)?.into());
1068                Ok(req?)
1069            },
1070            client,
1071            llm_token,
1072            app_version,
1073        )
1074        .await?;
1075
1076        #[cfg(feature = "llm-response-cache")]
1077        if let Some((cache, key)) = cache_key {
1078            cache.write_response(key, &serde_json::to_string(&response)?);
1079        }
1080
1081        Ok((response, usage))
1082    }
1083
1084    fn handle_api_response<T>(
1085        this: &WeakEntity<Self>,
1086        response: Result<(T, Option<EditPredictionUsage>)>,
1087        cx: &mut gpui::AsyncApp,
1088    ) -> Result<T> {
1089        match response {
1090            Ok((data, usage)) => {
1091                if let Some(usage) = usage {
1092                    this.update(cx, |this, cx| {
1093                        this.user_store.update(cx, |user_store, cx| {
1094                            user_store.update_edit_prediction_usage(usage, cx);
1095                        });
1096                    })
1097                    .ok();
1098                }
1099                Ok(data)
1100            }
1101            Err(err) => {
1102                if err.is::<ZedUpdateRequiredError>() {
1103                    cx.update(|cx| {
1104                        this.update(cx, |this, _cx| {
1105                            this.update_required = true;
1106                        })
1107                        .ok();
1108
1109                        let error_message: SharedString = err.to_string().into();
1110                        show_app_notification(
1111                            NotificationId::unique::<ZedUpdateRequiredError>(),
1112                            cx,
1113                            move |cx| {
1114                                cx.new(|cx| {
1115                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1116                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1117                                })
1118                            },
1119                        );
1120                    })
1121                    .ok();
1122                }
1123                Err(err)
1124            }
1125        }
1126    }
1127
1128    async fn send_api_request<Res>(
1129        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1130        client: Arc<Client>,
1131        llm_token: LlmApiToken,
1132        app_version: SemanticVersion,
1133    ) -> Result<(Res, Option<EditPredictionUsage>)>
1134    where
1135        Res: DeserializeOwned,
1136    {
1137        let http_client = client.http_client();
1138        let mut token = llm_token.acquire(&client).await?;
1139        let mut did_retry = false;
1140
1141        loop {
1142            let request_builder = http_client::Request::builder().method(Method::POST);
1143
1144            let request = build(
1145                request_builder
1146                    .header("Content-Type", "application/json")
1147                    .header("Authorization", format!("Bearer {}", token))
1148                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
1149            )?;
1150
1151            let mut response = http_client.send(request).await?;
1152
1153            if let Some(minimum_required_version) = response
1154                .headers()
1155                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1156                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
1157            {
1158                anyhow::ensure!(
1159                    app_version >= minimum_required_version,
1160                    ZedUpdateRequiredError {
1161                        minimum_version: minimum_required_version
1162                    }
1163                );
1164            }
1165
1166            if response.status().is_success() {
1167                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1168
1169                let mut body = Vec::new();
1170                response.body_mut().read_to_end(&mut body).await?;
1171                return Ok((serde_json::from_slice(&body)?, usage));
1172            } else if !did_retry
1173                && response
1174                    .headers()
1175                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1176                    .is_some()
1177            {
1178                did_retry = true;
1179                token = llm_token.refresh(&client).await?;
1180            } else {
1181                let mut body = String::new();
1182                response.body_mut().read_to_string(&mut body).await?;
1183                anyhow::bail!(
1184                    "Request failed with status: {:?}\nBody: {}",
1185                    response.status(),
1186                    body
1187                );
1188            }
1189        }
1190    }
1191
1192    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
1193    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1194
1195    // Refresh the related excerpts when the user just beguns editing after
1196    // an idle period, and after they pause editing.
1197    fn refresh_context_if_needed(
1198        &mut self,
1199        project: &Entity<Project>,
1200        buffer: &Entity<language::Buffer>,
1201        cursor_position: language::Anchor,
1202        cx: &mut Context<Self>,
1203    ) {
1204        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
1205            return;
1206        }
1207
1208        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1209            return;
1210        };
1211
1212        let now = Instant::now();
1213        let was_idle = zeta_project
1214            .refresh_context_timestamp
1215            .map_or(true, |timestamp| {
1216                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1217            });
1218        zeta_project.refresh_context_timestamp = Some(now);
1219        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1220            let buffer = buffer.clone();
1221            let project = project.clone();
1222            async move |this, cx| {
1223                if was_idle {
1224                    log::debug!("refetching edit prediction context after idle");
1225                } else {
1226                    cx.background_executor()
1227                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1228                        .await;
1229                    log::debug!("refetching edit prediction context after pause");
1230                }
1231                this.update(cx, |this, cx| {
1232                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
1233
1234                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
1235                        zeta_project.refresh_context_task = Some(task.log_err());
1236                    };
1237                })
1238                .ok()
1239            }
1240        }));
1241    }
1242
    // Refresh the related excerpts asynchronously.
    //
    // Builds a retrieval-planning prompt from the excerpt around `cursor_position` and the
    // project's recorded events, asks the model (through the `TOOL_NAME` search tool)
    // which searches to run, executes them with `retrieval_search::run_retrieval_searches`,
    // and on success stores the resulting excerpts in the project's `context`.
    //
    // Returns an already-completed `Ok(())` task when the project has no zeta state, the
    // context mode is not `Agentic`, or no excerpt can be selected at the cursor.
    //
    // NOTE(review): this function does not itself prevent concurrent refreshes.
    // `refresh_context_if_needed` stores the returned task in `refresh_context_task`,
    // replacing any previous one; whether that cancels the old task depends on gpui
    // `Task` drop semantics — confirm.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        // Buffers without a file are reported to the model as "untitled".
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                // Events that `to_request_event` can't convert are silently skipped.
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // JSON schema and description for the search tool, generated once and reused.
        // The `unwrap` assumes the generated schema for `SearchToolInput` has a top-level
        // string "description" (presumably from the type's doc comment — verify).
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            // The only tool offered is the search tool; its calls carry the queries.
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "llm-response-cache")]
        let llm_response_cache = self.llm_response_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "llm-response-cache")]
                llm_response_cache,
            )
            .await;
            // NOTE(review): every `?` from here until the final `this.update` returns
            // before `refresh_context_task` is cleared, leaving the finished task stored.
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            // Only the last choice is inspected.
            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect queries from every call to the search tool; calls to any other tool
            // are logged and ignored, while malformed arguments abort the refresh.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            // The search result is applied below; its error (if any) is returned only
            // after the stored task handle has been cleared.
            let related_excerpts_result =
                retrieval_search::run_retrieval_searches(project.clone(), queries, cx).await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Clear the stored task handle, report completion, then either store the new
            // excerpts or return the search error. The outer `?` propagates a failure to
            // update the (possibly released) entity.
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
1449
1450    fn gather_nearby_diagnostics(
1451        cursor_offset: usize,
1452        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1453        snapshot: &BufferSnapshot,
1454        max_diagnostics_bytes: usize,
1455    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1456        // TODO: Could make this more efficient
1457        let mut diagnostic_groups = Vec::new();
1458        for (language_server_id, diagnostics) in diagnostic_sets {
1459            let mut groups = Vec::new();
1460            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1461            diagnostic_groups.extend(
1462                groups
1463                    .into_iter()
1464                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1465            );
1466        }
1467
1468        // sort by proximity to cursor
1469        diagnostic_groups.sort_by_key(|group| {
1470            let range = &group.entries[group.primary_ix].range;
1471            if range.start >= cursor_offset {
1472                range.start - cursor_offset
1473            } else if cursor_offset >= range.end {
1474                cursor_offset - range.end
1475            } else {
1476                (cursor_offset - range.start).min(range.end - cursor_offset)
1477            }
1478        });
1479
1480        let mut results = Vec::new();
1481        let mut diagnostic_groups_truncated = false;
1482        let mut diagnostics_byte_count = 0;
1483        for group in diagnostic_groups {
1484            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1485            diagnostics_byte_count += raw_value.get().len();
1486            if diagnostics_byte_count > max_diagnostics_bytes {
1487                diagnostic_groups_truncated = true;
1488                break;
1489            }
1490            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1491        }
1492
1493        (results, diagnostic_groups_truncated)
1494    }
1495
1496    // TODO: Dedupe with similar code in request_prediction?
1497    pub fn cloud_request_for_zeta_cli(
1498        &mut self,
1499        project: &Entity<Project>,
1500        buffer: &Entity<Buffer>,
1501        position: language::Anchor,
1502        cx: &mut Context<Self>,
1503    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1504        let project_state = self.projects.get(&project.entity_id());
1505
1506        let index_state = project_state.map(|state| {
1507            state
1508                .syntax_index
1509                .read_with(cx, |index, _cx| index.state().clone())
1510        });
1511        let options = self.options.clone();
1512        let snapshot = buffer.read(cx).snapshot();
1513        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1514            return Task::ready(Err(anyhow!("No file path for excerpt")));
1515        };
1516        let worktree_snapshots = project
1517            .read(cx)
1518            .worktrees(cx)
1519            .map(|worktree| worktree.read(cx).snapshot())
1520            .collect::<Vec<_>>();
1521
1522        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1523            let mut path = f.worktree.read(cx).absolutize(&f.path);
1524            if path.pop() { Some(path) } else { None }
1525        });
1526
1527        cx.background_spawn(async move {
1528            let index_state = if let Some(index_state) = index_state {
1529                Some(index_state.lock_owned().await)
1530            } else {
1531                None
1532            };
1533
1534            let cursor_point = position.to_point(&snapshot);
1535
1536            let debug_info = true;
1537            EditPredictionContext::gather_context(
1538                cursor_point,
1539                &snapshot,
1540                parent_abs_path.as_deref(),
1541                match &options.context {
1542                    ContextMode::Agentic(_) => {
1543                        // TODO
1544                        panic!("Llm mode not supported in zeta cli yet");
1545                    }
1546                    ContextMode::Syntax(edit_prediction_context_options) => {
1547                        edit_prediction_context_options
1548                    }
1549                },
1550                index_state.as_deref(),
1551            )
1552            .context("Failed to select excerpt")
1553            .map(|context| {
1554                make_syntax_context_cloud_request(
1555                    excerpt_path.into(),
1556                    context,
1557                    // TODO pass everything
1558                    Vec::new(),
1559                    false,
1560                    Vec::new(),
1561                    false,
1562                    None,
1563                    debug_info,
1564                    &worktree_snapshots,
1565                    index_state.as_deref(),
1566                    Some(options.max_prompt_bytes),
1567                    options.prompt_format,
1568                )
1569            })
1570        })
1571    }
1572
1573    pub fn wait_for_initial_indexing(
1574        &mut self,
1575        project: &Entity<Project>,
1576        cx: &mut App,
1577    ) -> Task<Result<()>> {
1578        let zeta_project = self.get_or_init_zeta_project(project, cx);
1579        zeta_project
1580            .syntax_index
1581            .read(cx)
1582            .wait_for_initial_file_indexing(cx)
1583    }
1584}
1585
1586pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1587    let choice = res.choices.pop()?;
1588    let output_text = match choice.message {
1589        open_ai::RequestMessage::Assistant {
1590            content: Some(open_ai::MessageContent::Plain(content)),
1591            ..
1592        } => content,
1593        open_ai::RequestMessage::Assistant {
1594            content: Some(open_ai::MessageContent::Multipart(mut content)),
1595            ..
1596        } => {
1597            if content.is_empty() {
1598                log::error!("No output from Baseten completion response");
1599                return None;
1600            }
1601
1602            match content.remove(0) {
1603                open_ai::MessagePart::Text { text } => text,
1604                open_ai::MessagePart::Image { .. } => {
1605                    log::error!("Expected text, got an image");
1606                    return None;
1607                }
1608            }
1609        }
1610        _ => {
1611            log::error!("Invalid response message: {:?}", choice.message);
1612            return None;
1613        }
1614    };
1615    Some(output_text)
1616}
1617
1618#[derive(Error, Debug)]
1619#[error(
1620    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1621)]
1622pub struct ZedUpdateRequiredError {
1623    minimum_version: SemanticVersion,
1624}
1625
1626fn make_syntax_context_cloud_request(
1627    excerpt_path: Arc<Path>,
1628    context: EditPredictionContext,
1629    events: Vec<predict_edits_v3::Event>,
1630    can_collect_data: bool,
1631    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1632    diagnostic_groups_truncated: bool,
1633    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1634    debug_info: bool,
1635    worktrees: &Vec<worktree::Snapshot>,
1636    index_state: Option<&SyntaxIndexState>,
1637    prompt_max_bytes: Option<usize>,
1638    prompt_format: PromptFormat,
1639) -> predict_edits_v3::PredictEditsRequest {
1640    let mut signatures = Vec::new();
1641    let mut declaration_to_signature_index = HashMap::default();
1642    let mut referenced_declarations = Vec::new();
1643
1644    for snippet in context.declarations {
1645        let project_entry_id = snippet.declaration.project_entry_id();
1646        let Some(path) = worktrees.iter().find_map(|worktree| {
1647            worktree.entry_for_id(project_entry_id).map(|entry| {
1648                let mut full_path = RelPathBuf::new();
1649                full_path.push(worktree.root_name());
1650                full_path.push(&entry.path);
1651                full_path
1652            })
1653        }) else {
1654            continue;
1655        };
1656
1657        let parent_index = index_state.and_then(|index_state| {
1658            snippet.declaration.parent().and_then(|parent| {
1659                add_signature(
1660                    parent,
1661                    &mut declaration_to_signature_index,
1662                    &mut signatures,
1663                    index_state,
1664                )
1665            })
1666        });
1667
1668        let (text, text_is_truncated) = snippet.declaration.item_text();
1669        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1670            path: path.as_std_path().into(),
1671            text: text.into(),
1672            range: snippet.declaration.item_line_range(),
1673            text_is_truncated,
1674            signature_range: snippet.declaration.signature_range_in_item_text(),
1675            parent_index,
1676            signature_score: snippet.score(DeclarationStyle::Signature),
1677            declaration_score: snippet.score(DeclarationStyle::Declaration),
1678            score_components: snippet.components,
1679        });
1680    }
1681
1682    let excerpt_parent = index_state.and_then(|index_state| {
1683        context
1684            .excerpt
1685            .parent_declarations
1686            .last()
1687            .and_then(|(parent, _)| {
1688                add_signature(
1689                    *parent,
1690                    &mut declaration_to_signature_index,
1691                    &mut signatures,
1692                    index_state,
1693                )
1694            })
1695    });
1696
1697    predict_edits_v3::PredictEditsRequest {
1698        excerpt_path,
1699        excerpt: context.excerpt_text.body,
1700        excerpt_line_range: context.excerpt.line_range,
1701        excerpt_range: context.excerpt.range,
1702        cursor_point: predict_edits_v3::Point {
1703            line: predict_edits_v3::Line(context.cursor_point.row),
1704            column: context.cursor_point.column,
1705        },
1706        referenced_declarations,
1707        included_files: vec![],
1708        signatures,
1709        excerpt_parent,
1710        events,
1711        can_collect_data,
1712        diagnostic_groups,
1713        diagnostic_groups_truncated,
1714        git_info,
1715        debug_info,
1716        prompt_max_bytes,
1717        prompt_format,
1718    }
1719}
1720
1721fn add_signature(
1722    declaration_id: DeclarationId,
1723    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1724    signatures: &mut Vec<Signature>,
1725    index: &SyntaxIndexState,
1726) -> Option<usize> {
1727    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1728        return Some(*signature_index);
1729    }
1730    let Some(parent_declaration) = index.declaration(declaration_id) else {
1731        log::error!("bug: missing parent declaration");
1732        return None;
1733    };
1734    let parent_index = parent_declaration.parent().and_then(|parent| {
1735        add_signature(parent, declaration_to_signature_index, signatures, index)
1736    });
1737    let (text, text_is_truncated) = parent_declaration.signature_text();
1738    let signature_index = signatures.len();
1739    signatures.push(Signature {
1740        text: text.into(),
1741        text_is_truncated,
1742        parent_index,
1743        range: parent_declaration.signature_line_range(),
1744    });
1745    declaration_to_signature_index.insert(declaration_id, signature_index);
1746    Some(signature_index)
1747}
1748
// Tests for Zeta's prediction flow. Every `/predict_edits/raw` HTTP call made by Zeta is
// forwarded to the test over a channel by the fake client built in `init_test`, together with a
// oneshot sender the test uses to supply the model's response.
#[cfg(test)]
mod tests {
    use std::{path::Path, sync::Arc};

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::OffsetRangeExt as _;
    use open_ai::Usage;
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    // Covers how the current prediction is reported per buffer:
    // - a prediction for the cursor's own file is `Local`;
    // - after a context refresh, a prediction targeting another file is a `Jump` when viewed from
    //   the original buffer, and `Local` when viewed from the target file's buffer.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // NOTE(review): this tree root is a plain "/root" while `Project::test` below uses
        // `path!("/root")`; confirm FakeFs treats the two the same on Windows.
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        // The model answers with a single search tool call whose glob targets root/2.txt.
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        // Drop the prediction made above before requesting the next one.
        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        // NOTE(review): unlike the other mocked diffs, this one has no `@@` hunk header;
        // presumably intentional, to exercise parsing without it — confirm.
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        // Once the target file is open, the same prediction is reported as local to it.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    // The model's diff replacing "How" with "How are you?" should yield exactly one edit:
    // an insertion of " are you?" at row 1, column 3.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    // An edit made to a registered buffer (inserting "How" on the empty second line) must show
    // up as a unified diff inside the prompt sent to the model.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Register before editing so the edit below is observed by Zeta.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Offset 7 is the start of the empty second line ("Hello!\n" is 7 bytes).
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    // Skipped until we start including diagnostics in prompt
    // #[gpui::test]
    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
    //     let (zeta, mut req_rx) = init_test(cx);
    //     let fs = FakeFs::new(cx.executor());
    //     fs.insert_tree(
    //         "/root",
    //         json!({
    //             "foo.md": "Hello!\nBye"
    //         }),
    //     )
    //     .await;
    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
    //     let diagnostic = lsp::Diagnostic {
    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
    //         ..Default::default()
    //     };

    //     project.update(cx, |project, cx| {
    //         project.lsp_store().update(cx, |lsp_store, cx| {
    //             // Create some diagnostics
    //             lsp_store
    //                 .update_diagnostics(
    //                     LanguageServerId(0),
    //                     lsp::PublishDiagnosticsParams {
    //                         uri: path_to_buffer_uri.clone(),
    //                         diagnostics: vec![diagnostic],
    //                         version: None,
    //                     },
    //                     None,
    //                     language::DiagnosticSourceKind::Pushed,
    //                     &[],
    //                     cx,
    //                 )
    //                 .unwrap();
    //         });
    //     });

    //     let buffer = project
    //         .update(cx, |project, cx| {
    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
    //             project.open_buffer(path, cx)
    //         })
    //         .await
    //         .unwrap();

    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    //     let position = snapshot.anchor_before(language::Point::new(0, 0));

    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
    //         zeta.request_prediction(&project, &buffer, position, cx)
    //     });

    //     let (request, _respond_tx) = req_rx.next().await.unwrap();

    //     assert_eq!(request.diagnostic_groups.len(), 1);
    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
    //         .unwrap();
    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
    //     assert_eq!(
    //         value,
    //         json!({
    //             "entries": [{
    //                 "range": {
    //                     "start": 8,
    //                     "end": 10
    //                 },
    //                 "diagnostic": {
    //                     "source": null,
    //                     "code": null,
    //                     "code_description": null,
    //                     "severity": 1,
    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
    //                     "markdown": null,
    //                     "group_id": 0,
    //                     "is_primary": true,
    //                     "is_disk_based": false,
    //                     "is_unnecessary": false,
    //                     "source_kind": "Pushed",
    //                     "data": null,
    //                     "underline": true
    //                 }
    //             }],
    //             "primary_ix": 0
    //         })
    //     );
    // }

    // Builds a single-choice assistant response whose content is `text` as plain text, with no
    // tool calls and zeroed token usage.
    fn model_response(text: &str) -> open_ai::Response {
        open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
                    tool_calls: vec![],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }

    // Returns the prompt text of `request`. Panics unless the request consists of exactly one
    // user message with plain-text content.
    fn prompt_from_request(request: &open_ai::Request) -> &str {
        assert_eq!(request.messages.len(), 1);
        let open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(content),
            ..
        } = &request.messages[0]
        else {
            panic!(
                "Request does not have single user message of type Plain. {:#?}",
                request
            );
        };
        content
    }

    // Sets up test settings and logging, a fake HTTP client, a `Client`/`UserStore`, and the
    // global `Zeta` entity.
    //
    // Returns the Zeta entity and a receiver that yields each `/predict_edits/raw` request body,
    // deserialized as `open_ai::Request`, together with a oneshot sender. The test must send the
    // `open_ai::Response` through that sender; the fake HTTP call waits on it.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        // Routing by path:
                        // - /client/llm_tokens returns a fixed token;
                        // - /predict_edits/raw round-trips through the test;
                        // - any other path panics.
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                // Read errors are ignored; a short body then fails at parse time.
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                // `?` fails this fake request if the test drops `res_tx`
                                // without responding.
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}