// zeta2.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  10use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  11use collections::HashMap;
  12use edit_prediction_context::{
  13    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  14    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  15    SyntaxIndex, SyntaxIndexState,
  16};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::AsyncReadExt as _;
  19use futures::channel::{mpsc, oneshot};
  20use gpui::http_client::{AsyncBody, Method};
  21use gpui::{
  22    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  23    http_client, prelude::*,
  24};
  25use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  26use language::{BufferSnapshot, OffsetRangeExt};
  27use language_model::{LlmApiToken, RefreshLlmTokenListener};
  28use open_ai::FunctionDefinition;
  29use project::Project;
  30use release_channel::AppVersion;
  31use serde::de::DeserializeOwned;
  32use std::collections::{VecDeque, hash_map};
  33
  34use std::env;
  35use std::ops::Range;
  36use std::path::Path;
  37use std::str::FromStr as _;
  38use std::sync::{Arc, LazyLock};
  39use std::time::{Duration, Instant};
  40use thiserror::Error;
  41use util::rel_path::RelPathBuf;
  42use util::{LogErrorFuture, TryFutureExt};
  43use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  44
  45pub mod merge_excerpts;
  46mod prediction;
  47mod provider;
  48pub mod retrieval_search;
  49pub mod udiff;
  50
  51use crate::merge_excerpts::merge_excerpts;
  52use crate::prediction::EditPrediction;
  53pub use crate::prediction::EditPredictionId;
  54pub use provider::ZetaEditPredictionProvider;
  55
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Default sizing for the excerpt taken around the cursor: byte budget
/// bounds, plus what fraction of the excerpt's bytes should land before
/// the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim for roughly half of the excerpt before the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};
  64
/// Default context-gathering strategy (agentic mode).
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Default options for [`ContextMode::Agentic`].
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Default options for [`ContextMode::Syntax`].
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        // 0 disables declaration retrieval by default — presumably overridden
        // by callers that want syntax-based retrieval; TODO confirm.
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default top-level options for [`Zeta`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
  90
  91static USE_OLLAMA: LazyLock<bool> =
  92    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
  93static MODEL_ID: LazyLock<String> = LazyLock::new(|| {
  94    env::var("ZED_ZETA2_MODEL").unwrap_or(if *USE_OLLAMA {
  95        "qwen3-coder:30b".to_string()
  96    } else {
  97        "yqvev8r3".to_string()
  98    })
  99});
 100static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 101    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 102        if *USE_OLLAMA {
 103            Some("http://localhost:11434/v1/chat/completions".into())
 104        } else {
 105            None
 106        }
 107    })
 108});
 109
/// Feature flag gating the zeta2 edit-prediction experience.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    /// Not auto-enabled for staff; the flag must be turned on explicitly.
    fn enabled_for_staff() -> bool {
        false
    }
}
 119
/// App-global handle to the single [`Zeta`] entity (see [`Zeta::global`]).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 124
/// Coordinates edit-prediction state for every project in the application.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token used to authenticate LLM requests; refreshed by
    /// `_llm_token_subscription`.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// NOTE(review): presumably set when the server demands a newer Zed
    /// version (cf. `MINIMUM_REQUIRED_VERSION_HEADER_NAME`) — not toggled in
    /// this chunk; confirm against the rest of the file.
    update_required: bool,
    /// When present, debug events are mirrored to this channel
    /// (see [`Zeta::debug_info`]).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}
 135
/// Tunable knobs for prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    /// Byte budget for the rendered prompt.
    pub max_prompt_bytes: usize,
    /// Byte budget for diagnostics gathered near the cursor.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    /// Consecutive edits to the same buffer within this interval are
    /// coalesced into a single history event.
    pub buffer_change_grouping_interval: Duration,
}

/// How context for a prediction request is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Model-driven retrieval (search queries generated by the model).
    Agentic(AgenticContextOptions),
    /// Retrieval driven by the local syntax index.
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
 156
 157impl ContextMode {
 158    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 159        match self {
 160            ContextMode::Agentic(options) => &options.excerpt,
 161            ContextMode::Syntax(options) => &options.excerpt,
 162        }
 163    }
 164}
 165
/// Debug events emitted over the channel returned by [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// Prompt used for search-query generation — presumably what is sent to
    /// the model; confirm against the retrieval code.
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering context before the request was issued.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// The locally-rendered prompt, or an error message if rendering failed.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error string) and the elapsed time.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 206
/// Per-project prediction state.
struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer-change history, oldest first; capped at
    /// [`MAX_EVENT_COUNT`] (see `push_event`).
    events: VecDeque<Event>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// Context ranges gathered per buffer; `None` until retrieval has run.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}

/// The most recent prediction, paired with the buffer that requested it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
 223
 224impl CurrentEditPrediction {
 225    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 226        let Some(new_edits) = self
 227            .prediction
 228            .interpolate(&self.prediction.buffer.read(cx))
 229        else {
 230            return false;
 231        };
 232
 233        if self.prediction.buffer != old_prediction.prediction.buffer {
 234            return true;
 235        }
 236
 237        let Some(old_edits) = old_prediction
 238            .prediction
 239            .interpolate(&old_prediction.prediction.buffer.read(cx))
 240        else {
 241            return true;
 242        };
 243
 244        // This reduces the occurrence of UI thrash from replacing edits
 245        //
 246        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 247        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 248            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 249            && old_edits.len() == 1
 250            && new_edits.len() == 1
 251        {
 252            let (old_range, old_text) = &old_edits[0];
 253            let (new_range, new_text) = &new_edits[0];
 254            new_range == old_range && new_text.starts_with(old_text.as_ref())
 255        } else {
 256            true
 257        }
 258    }
 259}
 260
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer; this buffer requested it.
    Jump { prediction: &'a EditPrediction },
}

/// Per-buffer tracking state created by `Zeta::register_buffer_impl`.
struct RegisteredBuffer {
    // Snapshot as of the last recorded change; compared against fresh
    // snapshots in `report_changes_for_buffer`.
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// An entry in the per-project edit history sent with prediction requests.
#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
 281
 282impl Event {
 283    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
 284        match self {
 285            Event::BufferChange {
 286                old_snapshot,
 287                new_snapshot,
 288                ..
 289            } => {
 290                let path = new_snapshot.file().map(|f| f.full_path(cx));
 291
 292                let old_path = old_snapshot.file().and_then(|f| {
 293                    let old_path = f.full_path(cx);
 294                    if Some(&old_path) != path.as_ref() {
 295                        Some(old_path)
 296                    } else {
 297                        None
 298                    }
 299                });
 300
 301                // TODO [zeta2] move to bg?
 302                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 303
 304                if path == old_path && diff.is_empty() {
 305                    None
 306                } else {
 307                    Some(predict_edits_v3::Event::BufferChange {
 308                        old_path,
 309                        path,
 310                        diff,
 311                        //todo: Actually detect if this edit was predicted or not
 312                        predicted: false,
 313                    })
 314                }
 315            }
 316        }
 317    }
 318}
 319
 320impl Zeta {
    /// Returns the app-wide [`Zeta`] instance, if one has been created.
    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
    }
 324
 325    pub fn global(
 326        client: &Arc<Client>,
 327        user_store: &Entity<UserStore>,
 328        cx: &mut App,
 329    ) -> Entity<Self> {
 330        cx.try_global::<ZetaGlobal>()
 331            .map(|global| global.0.clone())
 332            .unwrap_or_else(|| {
 333                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 334                cx.set_global(ZetaGlobal(zeta.clone()));
 335                zeta
 336            })
 337    }
 338
    /// Creates a new `Zeta` with default options and an empty project map.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the LLM token whenever the listener signals a refresh
            // is needed.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
        }
    }
 364
 365    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 366        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 367        self.debug_tx = Some(debug_watch_tx);
 368        debug_watch_rx
 369    }
 370
    /// Current options (see [`ZetaOptions`]).
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    /// Drops the recorded edit-history events for every project.
    pub fn clear_history(&mut self) {
        for zeta_project in self.projects.values_mut() {
            zeta_project.events.clear();
        }
    }
 384
 385    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
 386        self.projects
 387            .get(&project.entity_id())
 388            .map(|project| project.events.iter())
 389            .into_iter()
 390            .flatten()
 391    }
 392
 393    pub fn context_for_project(
 394        &self,
 395        project: &Entity<Project>,
 396    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 397        self.projects
 398            .get(&project.entity_id())
 399            .and_then(|project| {
 400                Some(
 401                    project
 402                        .context
 403                        .as_ref()?
 404                        .iter()
 405                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 406                )
 407            })
 408            .into_iter()
 409            .flatten()
 410    }
 411
    /// The user's edit-prediction usage, if known.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }

    /// Ensures per-project state (including the syntax index) exists for
    /// `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking `buffer`, recording its edits into the project's
    /// history.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 429
    /// Returns the [`ZetaProject`] for `project`, creating it (and kicking
    /// off syntax indexing) on first use.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: cx.new(|cx| {
                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                }),
                events: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
            })
    }
 450
    /// Returns the [`RegisteredBuffer`] for `buffer`, creating it and its
    /// subscriptions on first registration.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record a history event for every edit, as long as
                        // the project is still alive.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer entity is
                        // released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 488
    /// Records a history event for `buffer` if its contents changed since the
    /// last snapshot we took, and returns the latest snapshot.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Version vectors differ iff the buffer changed since we last looked.
        if new_snapshot.version != registered_buffer.snapshot.version {
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                buffer_change_grouping_interval,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }
 516
    /// Appends `event` to the project's history, coalescing it with the
    /// previous event when it continues the same buffer's edits within
    /// `buffer_change_grouping_interval`.
    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            // Merge only when the new event picks up exactly where the last
            // one left off (same buffer, same version) and is recent enough.
            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }
 555
 556    fn current_prediction_for_buffer(
 557        &self,
 558        buffer: &Entity<Buffer>,
 559        project: &Entity<Project>,
 560        cx: &App,
 561    ) -> Option<BufferEditPrediction<'_>> {
 562        let project_state = self.projects.get(&project.entity_id())?;
 563
 564        let CurrentEditPrediction {
 565            requested_by_buffer_id,
 566            prediction,
 567        } = project_state.current_prediction.as_ref()?;
 568
 569        if prediction.targets_buffer(buffer.read(cx)) {
 570            Some(BufferEditPrediction::Local { prediction })
 571        } else if *requested_by_buffer_id == buffer.entity_id() {
 572            Some(BufferEditPrediction::Jump { prediction })
 573        } else {
 574            None
 575        }
 576    }
 577
    /// Takes the project's current prediction and reports its acceptance to
    /// the prediction endpoint.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // `take` also clears the current prediction locally.
        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the default endpoint
            // (useful for local development).
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 622
 623    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 624        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 625            project_state.current_prediction.take();
 626        };
 627    }
 628
 629    pub fn refresh_prediction(
 630        &mut self,
 631        project: &Entity<Project>,
 632        buffer: &Entity<Buffer>,
 633        position: language::Anchor,
 634        cx: &mut Context<Self>,
 635    ) -> Task<Result<()>> {
 636        let request_task = self.request_prediction(project, buffer, position, cx);
 637        let buffer = buffer.clone();
 638        let project = project.clone();
 639
 640        cx.spawn(async move |this, cx| {
 641            if let Some(prediction) = request_task.await? {
 642                this.update(cx, |this, cx| {
 643                    let project_state = this
 644                        .projects
 645                        .get_mut(&project.entity_id())
 646                        .context("Project not found")?;
 647
 648                    let new_prediction = CurrentEditPrediction {
 649                        requested_by_buffer_id: buffer.entity_id(),
 650                        prediction: prediction,
 651                    };
 652
 653                    if project_state
 654                        .current_prediction
 655                        .as_ref()
 656                        .is_none_or(|old_prediction| {
 657                            new_prediction.should_replace_prediction(&old_prediction, cx)
 658                        })
 659                    {
 660                        project_state.current_prediction = Some(new_prediction);
 661                    }
 662                    anyhow::Ok(())
 663                })??;
 664            }
 665            Ok(())
 666        })
 667    }
 668
 669    pub fn request_prediction(
 670        &mut self,
 671        project: &Entity<Project>,
 672        active_buffer: &Entity<Buffer>,
 673        position: language::Anchor,
 674        cx: &mut Context<Self>,
 675    ) -> Task<Result<Option<EditPrediction>>> {
 676        let project_state = self.projects.get(&project.entity_id());
 677
 678        let index_state = project_state.map(|state| {
 679            state
 680                .syntax_index
 681                .read_with(cx, |index, _cx| index.state().clone())
 682        });
 683        let options = self.options.clone();
 684        let active_snapshot = active_buffer.read(cx).snapshot();
 685        let Some(excerpt_path) = active_snapshot
 686            .file()
 687            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
 688        else {
 689            return Task::ready(Err(anyhow!("No file path for excerpt")));
 690        };
 691        let client = self.client.clone();
 692        let llm_token = self.llm_token.clone();
 693        let app_version = AppVersion::global(cx);
 694        let worktree_snapshots = project
 695            .read(cx)
 696            .worktrees(cx)
 697            .map(|worktree| worktree.read(cx).snapshot())
 698            .collect::<Vec<_>>();
 699        let debug_tx = self.debug_tx.clone();
 700
 701        let events = project_state
 702            .map(|state| {
 703                state
 704                    .events
 705                    .iter()
 706                    .filter_map(|event| event.to_request_event(cx))
 707                    .collect::<Vec<_>>()
 708            })
 709            .unwrap_or_default();
 710
 711        let diagnostics = active_snapshot.diagnostic_sets().clone();
 712
 713        let parent_abs_path =
 714            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
 715                let mut path = f.worktree.read(cx).absolutize(&f.path);
 716                if path.pop() { Some(path) } else { None }
 717            });
 718
 719        // TODO data collection
 720        let can_collect_data = cx.is_staff();
 721
 722        let mut included_files = project_state
 723            .and_then(|project_state| project_state.context.as_ref())
 724            .unwrap_or(&HashMap::default())
 725            .iter()
 726            .filter_map(|(buffer_entity, ranges)| {
 727                let buffer = buffer_entity.read(cx);
 728                Some((
 729                    buffer_entity.clone(),
 730                    buffer.snapshot(),
 731                    buffer.file()?.full_path(cx).into(),
 732                    ranges.clone(),
 733                ))
 734            })
 735            .collect::<Vec<_>>();
 736
 737        let request_task = cx.background_spawn({
 738            let active_buffer = active_buffer.clone();
 739            async move {
 740                let index_state = if let Some(index_state) = index_state {
 741                    Some(index_state.lock_owned().await)
 742                } else {
 743                    None
 744                };
 745
 746                let cursor_offset = position.to_offset(&active_snapshot);
 747                let cursor_point = cursor_offset.to_point(&active_snapshot);
 748
 749                let before_retrieval = chrono::Utc::now();
 750
 751                let (diagnostic_groups, diagnostic_groups_truncated) =
 752                    Self::gather_nearby_diagnostics(
 753                        cursor_offset,
 754                        &diagnostics,
 755                        &active_snapshot,
 756                        options.max_diagnostic_bytes,
 757                    );
 758
 759                let cloud_request = match options.context {
 760                    ContextMode::Agentic(context_options) => {
 761                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
 762                            cursor_point,
 763                            &active_snapshot,
 764                            &context_options.excerpt,
 765                            index_state.as_deref(),
 766                        ) else {
 767                            return Ok((None, None));
 768                        };
 769
 770                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
 771                            ..active_snapshot.anchor_before(excerpt.range.end);
 772
 773                        if let Some(buffer_ix) =
 774                            included_files.iter().position(|(_, snapshot, _, _)| {
 775                                snapshot.remote_id() == active_snapshot.remote_id()
 776                            })
 777                        {
 778                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
 779                            let range_ix = ranges
 780                                .binary_search_by(|probe| {
 781                                    probe
 782                                        .start
 783                                        .cmp(&excerpt_anchor_range.start, buffer)
 784                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
 785                                })
 786                                .unwrap_or_else(|ix| ix);
 787
 788                            ranges.insert(range_ix, excerpt_anchor_range);
 789                            let last_ix = included_files.len() - 1;
 790                            included_files.swap(buffer_ix, last_ix);
 791                        } else {
 792                            included_files.push((
 793                                active_buffer.clone(),
 794                                active_snapshot,
 795                                excerpt_path.clone(),
 796                                vec![excerpt_anchor_range],
 797                            ));
 798                        }
 799
 800                        let included_files = included_files
 801                            .iter()
 802                            .map(|(_, buffer, path, ranges)| {
 803                                let excerpts = merge_excerpts(
 804                                    &buffer,
 805                                    ranges.iter().map(|range| {
 806                                        let point_range = range.to_point(&buffer);
 807                                        Line(point_range.start.row)..Line(point_range.end.row)
 808                                    }),
 809                                );
 810                                predict_edits_v3::IncludedFile {
 811                                    path: path.clone(),
 812                                    max_row: Line(buffer.max_point().row),
 813                                    excerpts,
 814                                }
 815                            })
 816                            .collect::<Vec<_>>();
 817
 818                        predict_edits_v3::PredictEditsRequest {
 819                            excerpt_path,
 820                            excerpt: String::new(),
 821                            excerpt_line_range: Line(0)..Line(0),
 822                            excerpt_range: 0..0,
 823                            cursor_point: predict_edits_v3::Point {
 824                                line: predict_edits_v3::Line(cursor_point.row),
 825                                column: cursor_point.column,
 826                            },
 827                            included_files,
 828                            referenced_declarations: vec![],
 829                            events,
 830                            can_collect_data,
 831                            diagnostic_groups,
 832                            diagnostic_groups_truncated,
 833                            debug_info: debug_tx.is_some(),
 834                            prompt_max_bytes: Some(options.max_prompt_bytes),
 835                            prompt_format: options.prompt_format,
 836                            // TODO [zeta2]
 837                            signatures: vec![],
 838                            excerpt_parent: None,
 839                            git_info: None,
 840                        }
 841                    }
 842                    ContextMode::Syntax(context_options) => {
 843                        let Some(context) = EditPredictionContext::gather_context(
 844                            cursor_point,
 845                            &active_snapshot,
 846                            parent_abs_path.as_deref(),
 847                            &context_options,
 848                            index_state.as_deref(),
 849                        ) else {
 850                            return Ok((None, None));
 851                        };
 852
 853                        make_syntax_context_cloud_request(
 854                            excerpt_path,
 855                            context,
 856                            events,
 857                            can_collect_data,
 858                            diagnostic_groups,
 859                            diagnostic_groups_truncated,
 860                            None,
 861                            debug_tx.is_some(),
 862                            &worktree_snapshots,
 863                            index_state.as_deref(),
 864                            Some(options.max_prompt_bytes),
 865                            options.prompt_format,
 866                        )
 867                    }
 868                };
 869
 870                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
 871
 872                let retrieval_time = chrono::Utc::now() - before_retrieval;
 873
 874                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
 875                    let (response_tx, response_rx) = oneshot::channel();
 876
 877                    debug_tx
 878                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
 879                            ZetaEditPredictionDebugInfo {
 880                                request: cloud_request.clone(),
 881                                retrieval_time,
 882                                buffer: active_buffer.downgrade(),
 883                                local_prompt: match prompt_result.as_ref() {
 884                                    Ok((prompt, _)) => Ok(prompt.clone()),
 885                                    Err(err) => Err(err.to_string()),
 886                                },
 887                                position,
 888                                response_rx,
 889                            },
 890                        ))
 891                        .ok();
 892                    Some(response_tx)
 893                } else {
 894                    None
 895                };
 896
 897                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 898                    if let Some(debug_response_tx) = debug_response_tx {
 899                        debug_response_tx
 900                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
 901                            .ok();
 902                    }
 903                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 904                }
 905
 906                let (prompt, _) = prompt_result?;
 907                let request = open_ai::Request {
 908                    model: MODEL_ID.clone(),
 909                    messages: vec![open_ai::RequestMessage::User {
 910                        content: open_ai::MessageContent::Plain(prompt),
 911                    }],
 912                    stream: false,
 913                    max_completion_tokens: None,
 914                    stop: Default::default(),
 915                    temperature: 0.7,
 916                    tool_choice: None,
 917                    parallel_tool_calls: None,
 918                    tools: vec![],
 919                    prompt_cache_key: None,
 920                    reasoning_effort: None,
 921                };
 922
 923                log::trace!("Sending edit prediction request");
 924
 925                let before_request = chrono::Utc::now();
 926                let response =
 927                    Self::send_raw_llm_request(client, llm_token, app_version, request).await;
 928                let request_time = chrono::Utc::now() - before_request;
 929
 930                log::trace!("Got edit prediction response");
 931
 932                if let Some(debug_response_tx) = debug_response_tx {
 933                    debug_response_tx
 934                        .send((
 935                            response
 936                                .as_ref()
 937                                .map_err(|err| err.to_string())
 938                                .map(|response| response.0.clone()),
 939                            request_time,
 940                        ))
 941                        .ok();
 942                }
 943
 944                let (res, usage) = response?;
 945                let request_id = EditPredictionId(res.id.clone().into());
 946                let Some(mut output_text) = text_from_response(res) else {
 947                    return Ok((None, usage))
 948                };
 949
 950                if output_text.contains(CURSOR_MARKER) {
 951                    log::trace!("Stripping out {CURSOR_MARKER} from response");
 952                    output_text = output_text.replace(CURSOR_MARKER, "");
 953                }
 954
 955                let (edited_buffer_snapshot, edits) = match options.prompt_format {
 956                    PromptFormat::NumLinesUniDiff => {
 957                        crate::udiff::parse_diff(&output_text, |path| {
 958                            included_files
 959                                .iter()
 960                                .find_map(|(_, buffer, probe_path, ranges)| {
 961                                    if probe_path.as_ref() == path {
 962                                        Some((buffer, ranges.as_slice()))
 963                                    } else {
 964                                        None
 965                                    }
 966                                })
 967                        })
 968                        .await?
 969                    }
 970                    _ => {
 971                        bail!("unsupported prompt format {}", options.prompt_format)
 972                    }
 973                };
 974
 975                let edited_buffer = included_files
 976                    .iter()
 977                    .find_map(|(buffer, snapshot, _, _)| {
 978                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
 979                            Some(buffer.clone())
 980                        } else {
 981                            None
 982                        }
 983                    })
 984                    .context("Failed to find buffer in included_buffers, even though we just found the snapshot")?;
 985
 986                anyhow::Ok((Some((request_id, edited_buffer, edited_buffer_snapshot.clone(), edits)), usage))
 987            }
 988        });
 989
 990        cx.spawn({
 991            async move |this, cx| {
 992                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
 993                    Self::handle_api_response(&this, request_task.await, cx)?
 994                else {
 995                    return Ok(None);
 996                };
 997
 998                // TODO telemetry: duration, etc
 999                Ok(
1000                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
1001                        .await,
1002                )
1003            }
1004        })
1005    }
1006
1007    async fn send_raw_llm_request(
1008        client: Arc<Client>,
1009        llm_token: LlmApiToken,
1010        app_version: SemanticVersion,
1011        request: open_ai::Request,
1012    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1013        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1014            http_client::Url::parse(&predict_edits_url)?
1015        } else {
1016            client
1017                .http_client()
1018                .build_zed_llm_url("/predict_edits/raw", &[])?
1019        };
1020
1021        Self::send_api_request(
1022            |builder| {
1023                let req = builder
1024                    .uri(url.as_ref())
1025                    .body(serde_json::to_string(&request)?.into());
1026                Ok(req?)
1027            },
1028            client,
1029            llm_token,
1030            app_version,
1031        )
1032        .await
1033    }
1034
1035    fn handle_api_response<T>(
1036        this: &WeakEntity<Self>,
1037        response: Result<(T, Option<EditPredictionUsage>)>,
1038        cx: &mut gpui::AsyncApp,
1039    ) -> Result<T> {
1040        match response {
1041            Ok((data, usage)) => {
1042                if let Some(usage) = usage {
1043                    this.update(cx, |this, cx| {
1044                        this.user_store.update(cx, |user_store, cx| {
1045                            user_store.update_edit_prediction_usage(usage, cx);
1046                        });
1047                    })
1048                    .ok();
1049                }
1050                Ok(data)
1051            }
1052            Err(err) => {
1053                if err.is::<ZedUpdateRequiredError>() {
1054                    cx.update(|cx| {
1055                        this.update(cx, |this, _cx| {
1056                            this.update_required = true;
1057                        })
1058                        .ok();
1059
1060                        let error_message: SharedString = err.to_string().into();
1061                        show_app_notification(
1062                            NotificationId::unique::<ZedUpdateRequiredError>(),
1063                            cx,
1064                            move |cx| {
1065                                cx.new(|cx| {
1066                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1067                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1068                                })
1069                            },
1070                        );
1071                    })
1072                    .ok();
1073                }
1074                Err(err)
1075            }
1076        }
1077    }
1078
    /// Sends a POST request built by `build` to the Zed LLM API and
    /// deserializes the JSON response body into `Res`.
    ///
    /// Attaches the bearer token and Zed version headers, fails with
    /// [`ZedUpdateRequiredError`] when the server advertises a minimum
    /// required client version newer than `app_version`, and transparently
    /// refreshes an expired LLM token once before retrying.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // Runs at most twice: once with the cached token, and once more after
        // refreshing it if the server reports the token expired.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Checked on every response (success or not) so outdated clients
            // get a clear update-required error instead of a generic failure.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                // Usage headers are optional; parse failures are ignored.
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh and retry exactly once.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                // Any other failure: surface status and body to the caller.
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1142
    /// Gap since the last context refresh after which the user is considered
    /// idle; the next edit then triggers an immediate refresh.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied after an edit (when not idle) before refreshing the
    /// related-excerpt context.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1145
1146    // Refresh the related excerpts when the user just beguns editing after
1147    // an idle period, and after they pause editing.
1148    fn refresh_context_if_needed(
1149        &mut self,
1150        project: &Entity<Project>,
1151        buffer: &Entity<language::Buffer>,
1152        cursor_position: language::Anchor,
1153        cx: &mut Context<Self>,
1154    ) {
1155        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
1156            return;
1157        }
1158
1159        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1160            return;
1161        };
1162
1163        let now = Instant::now();
1164        let was_idle = zeta_project
1165            .refresh_context_timestamp
1166            .map_or(true, |timestamp| {
1167                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1168            });
1169        zeta_project.refresh_context_timestamp = Some(now);
1170        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1171            let buffer = buffer.clone();
1172            let project = project.clone();
1173            async move |this, cx| {
1174                if was_idle {
1175                    log::debug!("refetching edit prediction context after idle");
1176                } else {
1177                    cx.background_executor()
1178                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1179                        .await;
1180                    log::debug!("refetching edit prediction context after pause");
1181                }
1182                this.update(cx, |this, cx| {
1183                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
1184
1185                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
1186                        zeta_project.refresh_context_task = Some(task.log_err());
1187                    };
1188                })
1189                .ok()
1190            }
1191        }));
1192    }
1193
    /// Refreshes the related excerpts asynchronously: asks the model to plan
    /// search queries for the current cursor excerpt, runs those searches, and
    /// stores the resulting excerpts as the project's context.
    ///
    /// The returned task should be driven to completion; callers are expected
    /// to avoid spawning more than one concurrent refresh (see
    /// `refresh_context_if_needed`, which stores the task on the project).
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // No-op unless we track this project and are in agentic context mode.
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        // Select the excerpt surrounding the cursor; bail quietly if none fits.
        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        // Untitled buffers get a placeholder path.
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the search-planning prompt from the excerpt and recent events.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Computed once per process: the JSON schema and description for the
        // search tool offered to the model.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            // NOTE(review): panics if the generated schema has no top-level
            // description — presumably guaranteed by `SearchToolInput`'s doc
            // comment; confirm before changing that type.
            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response =
                Self::send_raw_llm_request(client, llm_token, app_version, request).await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            // Collect the search queries from the assistant's tool calls.
            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                // Calls to tools other than the search tool are logged and skipped.
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result =
                retrieval_search::run_retrieval_searches(project.clone(), queries, cx).await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Store the results on the project (if still tracked) and clear
            // the in-flight task handle.
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
1390
1391    fn gather_nearby_diagnostics(
1392        cursor_offset: usize,
1393        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1394        snapshot: &BufferSnapshot,
1395        max_diagnostics_bytes: usize,
1396    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1397        // TODO: Could make this more efficient
1398        let mut diagnostic_groups = Vec::new();
1399        for (language_server_id, diagnostics) in diagnostic_sets {
1400            let mut groups = Vec::new();
1401            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1402            diagnostic_groups.extend(
1403                groups
1404                    .into_iter()
1405                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1406            );
1407        }
1408
1409        // sort by proximity to cursor
1410        diagnostic_groups.sort_by_key(|group| {
1411            let range = &group.entries[group.primary_ix].range;
1412            if range.start >= cursor_offset {
1413                range.start - cursor_offset
1414            } else if cursor_offset >= range.end {
1415                cursor_offset - range.end
1416            } else {
1417                (cursor_offset - range.start).min(range.end - cursor_offset)
1418            }
1419        });
1420
1421        let mut results = Vec::new();
1422        let mut diagnostic_groups_truncated = false;
1423        let mut diagnostics_byte_count = 0;
1424        for group in diagnostic_groups {
1425            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1426            diagnostics_byte_count += raw_value.get().len();
1427            if diagnostics_byte_count > max_diagnostics_bytes {
1428                diagnostic_groups_truncated = true;
1429                break;
1430            }
1431            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1432        }
1433
1434        (results, diagnostic_groups_truncated)
1435    }
1436
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds the cloud `PredictEditsRequest` for the given cursor position,
    /// for use by the zeta CLI. Only syntax context mode is supported;
    /// agentic mode panics (see below).
    ///
    /// Reads worktree snapshots and the syntax index state on the foreground,
    /// then assembles the request on a background thread.
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, when the buffer
        // belongs to a worktree.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            // Hold the syntax index lock while gathering context.
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        // NOTE(review): the message says "Llm mode" but this is
                        // the `Agentic` variant — terminology looks stale; confirm.
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
1513
1514    pub fn wait_for_initial_indexing(
1515        &mut self,
1516        project: &Entity<Project>,
1517        cx: &mut App,
1518    ) -> Task<Result<()>> {
1519        let zeta_project = self.get_or_init_zeta_project(project, cx);
1520        zeta_project
1521            .syntax_index
1522            .read(cx)
1523            .wait_for_initial_file_indexing(cx)
1524    }
1525}
1526
1527pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1528    let choice = res.choices.pop()?;
1529    let output_text = match choice.message {
1530        open_ai::RequestMessage::Assistant {
1531            content: Some(open_ai::MessageContent::Plain(content)),
1532            ..
1533        } => content,
1534        open_ai::RequestMessage::Assistant {
1535            content: Some(open_ai::MessageContent::Multipart(mut content)),
1536            ..
1537        } => {
1538            if content.is_empty() {
1539                log::error!("No output from Baseten completion response");
1540                return None;
1541            }
1542
1543            match content.remove(0) {
1544                open_ai::MessagePart::Text { text } => text,
1545                open_ai::MessagePart::Image { .. } => {
1546                    log::error!("Expected text, got an image");
1547                    return None;
1548                }
1549            }
1550        }
1551        _ => {
1552            log::error!("Invalid response message: {:?}", choice.message);
1553            return None;
1554        }
1555    };
1556    Some(output_text)
1557}
1558
/// Error produced when the server's minimum-required-version response header
/// indicates this Zed build is too old to use edit predictions
/// (see `send_api_request`).
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    // Lowest Zed version the server will accept, parsed from the
    // MINIMUM_REQUIRED_VERSION_HEADER_NAME header.
    minimum_version: SemanticVersion,
}
1566
1567fn make_syntax_context_cloud_request(
1568    excerpt_path: Arc<Path>,
1569    context: EditPredictionContext,
1570    events: Vec<predict_edits_v3::Event>,
1571    can_collect_data: bool,
1572    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1573    diagnostic_groups_truncated: bool,
1574    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1575    debug_info: bool,
1576    worktrees: &Vec<worktree::Snapshot>,
1577    index_state: Option<&SyntaxIndexState>,
1578    prompt_max_bytes: Option<usize>,
1579    prompt_format: PromptFormat,
1580) -> predict_edits_v3::PredictEditsRequest {
1581    let mut signatures = Vec::new();
1582    let mut declaration_to_signature_index = HashMap::default();
1583    let mut referenced_declarations = Vec::new();
1584
1585    for snippet in context.declarations {
1586        let project_entry_id = snippet.declaration.project_entry_id();
1587        let Some(path) = worktrees.iter().find_map(|worktree| {
1588            worktree.entry_for_id(project_entry_id).map(|entry| {
1589                let mut full_path = RelPathBuf::new();
1590                full_path.push(worktree.root_name());
1591                full_path.push(&entry.path);
1592                full_path
1593            })
1594        }) else {
1595            continue;
1596        };
1597
1598        let parent_index = index_state.and_then(|index_state| {
1599            snippet.declaration.parent().and_then(|parent| {
1600                add_signature(
1601                    parent,
1602                    &mut declaration_to_signature_index,
1603                    &mut signatures,
1604                    index_state,
1605                )
1606            })
1607        });
1608
1609        let (text, text_is_truncated) = snippet.declaration.item_text();
1610        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1611            path: path.as_std_path().into(),
1612            text: text.into(),
1613            range: snippet.declaration.item_line_range(),
1614            text_is_truncated,
1615            signature_range: snippet.declaration.signature_range_in_item_text(),
1616            parent_index,
1617            signature_score: snippet.score(DeclarationStyle::Signature),
1618            declaration_score: snippet.score(DeclarationStyle::Declaration),
1619            score_components: snippet.components,
1620        });
1621    }
1622
1623    let excerpt_parent = index_state.and_then(|index_state| {
1624        context
1625            .excerpt
1626            .parent_declarations
1627            .last()
1628            .and_then(|(parent, _)| {
1629                add_signature(
1630                    *parent,
1631                    &mut declaration_to_signature_index,
1632                    &mut signatures,
1633                    index_state,
1634                )
1635            })
1636    });
1637
1638    predict_edits_v3::PredictEditsRequest {
1639        excerpt_path,
1640        excerpt: context.excerpt_text.body,
1641        excerpt_line_range: context.excerpt.line_range,
1642        excerpt_range: context.excerpt.range,
1643        cursor_point: predict_edits_v3::Point {
1644            line: predict_edits_v3::Line(context.cursor_point.row),
1645            column: context.cursor_point.column,
1646        },
1647        referenced_declarations,
1648        included_files: vec![],
1649        signatures,
1650        excerpt_parent,
1651        events,
1652        can_collect_data,
1653        diagnostic_groups,
1654        diagnostic_groups_truncated,
1655        git_info,
1656        debug_info,
1657        prompt_max_bytes,
1658        prompt_format,
1659    }
1660}
1661
1662fn add_signature(
1663    declaration_id: DeclarationId,
1664    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1665    signatures: &mut Vec<Signature>,
1666    index: &SyntaxIndexState,
1667) -> Option<usize> {
1668    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1669        return Some(*signature_index);
1670    }
1671    let Some(parent_declaration) = index.declaration(declaration_id) else {
1672        log::error!("bug: missing parent declaration");
1673        return None;
1674    };
1675    let parent_index = parent_declaration.parent().and_then(|parent| {
1676        add_signature(parent, declaration_to_signature_index, signatures, index)
1677    });
1678    let (text, text_is_truncated) = parent_declaration.signature_text();
1679    let signature_index = signatures.len();
1680    signatures.push(Signature {
1681        text: text.into(),
1682        text_is_truncated,
1683        parent_index,
1684        range: parent_declaration.signature_line_range(),
1685    });
1686    declaration_to_signature_index.insert(declaration_id, signature_index);
1687    Some(signature_index)
1688}
1689
#[cfg(test)]
mod tests {
    use std::{path::Path, sync::Arc};

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::OffsetRangeExt as _;
    use open_ai::Usage;
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    /// Exercises Zeta's bookkeeping of the "current" prediction: a prediction
    /// editing the cursor's own file reads back as `Local`; a context-refresh
    /// round trip (search tool call) completes; and a prediction editing a
    /// different file reads back as `Jump` from the source buffer but `Local`
    /// once the target buffer is opened.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // The model responds with a diff touching the file the cursor is in.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Respond with a search tool call rather than text, mimicking the
        // retrieval step of the context-gathering protocol.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // NOTE(review): unlike the other fixtures this diff omits the
        // "@@ ... @@" hunk header — confirm the diff parser tolerates that.
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        prediction_task.await.unwrap();
        // From buffer1's point of view, an edit to 2.txt is a jump target.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once 2.txt itself is open, the same prediction reads back as Local.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    /// Single prediction round trip: the request goes out, a unified diff
    /// comes back, and the resulting edit lands at the expected position with
    /// the expected inserted text.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The diff replaces "How" with "How are you?", so the single edit is
        // an insertion of " are you?" at line 1, column 3.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    /// Verifies that edits made to a registered buffer before a prediction
    /// request show up in the outgoing prompt as a unified diff of recent
    /// events.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Registration is what makes Zeta start recording edit events.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The edit above (inserting "How" into the blank line) must appear in
        // the prompt as a diff hunk.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    // Skipped until we start including diagnostics in prompt
    // #[gpui::test]
    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
    //     let (zeta, mut req_rx) = init_test(cx);
    //     let fs = FakeFs::new(cx.executor());
    //     fs.insert_tree(
    //         "/root",
    //         json!({
    //             "foo.md": "Hello!\nBye"
    //         }),
    //     )
    //     .await;
    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
    //     let diagnostic = lsp::Diagnostic {
    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
    //         ..Default::default()
    //     };

    //     project.update(cx, |project, cx| {
    //         project.lsp_store().update(cx, |lsp_store, cx| {
    //             // Create some diagnostics
    //             lsp_store
    //                 .update_diagnostics(
    //                     LanguageServerId(0),
    //                     lsp::PublishDiagnosticsParams {
    //                         uri: path_to_buffer_uri.clone(),
    //                         diagnostics: vec![diagnostic],
    //                         version: None,
    //                     },
    //                     None,
    //                     language::DiagnosticSourceKind::Pushed,
    //                     &[],
    //                     cx,
    //                 )
    //                 .unwrap();
    //         });
    //     });

    //     let buffer = project
    //         .update(cx, |project, cx| {
    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
    //             project.open_buffer(path, cx)
    //         })
    //         .await
    //         .unwrap();

    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    //     let position = snapshot.anchor_before(language::Point::new(0, 0));

    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
    //         zeta.request_prediction(&project, &buffer, position, cx)
    //     });

    //     let (request, _respond_tx) = req_rx.next().await.unwrap();

    //     assert_eq!(request.diagnostic_groups.len(), 1);
    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
    //         .unwrap();
    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
    //     assert_eq!(
    //         value,
    //         json!({
    //             "entries": [{
    //                 "range": {
    //                     "start": 8,
    //                     "end": 10
    //                 },
    //                 "diagnostic": {
    //                     "source": null,
    //                     "code": null,
    //                     "code_description": null,
    //                     "severity": 1,
    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
    //                     "markdown": null,
    //                     "group_id": 0,
    //                     "is_primary": true,
    //                     "is_disk_based": false,
    //                     "is_unnecessary": false,
    //                     "source_kind": "Pushed",
    //                     "data": null,
    //                     "underline": true
    //                 }
    //             }],
    //             "primary_ix": 0
    //         })
    //     );
    // }

    /// Wraps `text` in a minimal single-choice `open_ai::Response`, shaped
    /// like what the prediction endpoint returns.
    fn model_response(text: &str) -> open_ai::Response {
        open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
                    tool_calls: vec![],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }

    /// Extracts the prompt text from an outgoing request, asserting it
    /// consists of exactly one plain-text user message.
    fn prompt_from_request(request: &open_ai::Request) -> &str {
        assert_eq!(request.messages.len(), 1);
        let open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(content),
            ..
        } = &request.messages[0]
        else {
            panic!(
                "Request does not have single user message of type Plain. {:#?}",
                request
            );
        };
        content
    }

    /// Builds a `Zeta` entity backed by a fake HTTP client. The fake serves
    /// `/client/llm_tokens` directly and forwards each `/predict_edits/raw`
    /// request (plus a oneshot responder) over the returned channel, letting
    /// tests inspect requests and send back canned `open_ai::Response`s.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the request to the test and block this
                                // fake response on the test's reply.
                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}