1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
10use cloud_zeta2_prompt::retrieval_prompt::SearchToolInput;
11use collections::HashMap;
12use edit_prediction_context::{
13 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
14 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
15 SyntaxIndex, SyntaxIndexState,
16};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::AsyncReadExt as _;
19use futures::channel::{mpsc, oneshot};
20use gpui::http_client::{AsyncBody, Method};
21use gpui::{
22 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
23 http_client, prelude::*,
24};
25use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
26use language::{BufferSnapshot, OffsetRangeExt};
27use language_model::{LlmApiToken, RefreshLlmTokenListener};
28use open_ai::FunctionDefinition;
29use project::Project;
30use release_channel::AppVersion;
31use serde::de::DeserializeOwned;
32use std::collections::{VecDeque, hash_map};
33use uuid::Uuid;
34
35use std::ops::Range;
36use std::path::Path;
37use std::str::FromStr as _;
38use std::sync::Arc;
39use std::time::{Duration, Instant};
40use thiserror::Error;
41use util::rel_path::RelPathBuf;
42use util::{LogErrorFuture, TryFutureExt};
43use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
44
45pub mod merge_excerpts;
46mod prediction;
47mod provider;
48pub mod retrieval_search;
49pub mod udiff;
50
51use crate::merge_excerpts::merge_excerpts;
52use crate::prediction::EditPrediction;
53pub use crate::prediction::EditPredictionId;
54pub use provider::ZetaEditPredictionProvider;
55
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Default sizing for the excerpt selected around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim for roughly half of the excerpt bytes before the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Agentic (LLM-driven search) context retrieval is the default mode.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for syntax-index-based context gathering.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default [`ZetaOptions`] used by a freshly constructed [`Zeta`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    // Edits to the same buffer within this interval are coalesced into one event.
    buffer_change_grouping_interval: Duration::from_secs(1),
};
90
/// Feature flag gating the zeta2 edit prediction provider.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Not auto-enabled for staff; must be turned on explicitly.
    fn enabled_for_staff() -> bool {
        false
    }
}
100
/// Global handle to the singleton [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
105
/// Entry point to the zeta2 edit prediction engine. A single shared instance
/// (see [`Zeta::global`]) tracks per-project state.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Token used to authenticate with the LLM endpoints.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // Set when the server reported that this Zed version is too old.
    update_required: bool,
    // When present, debug events are streamed to this channel (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}
116
/// Tunable options for prediction requests and context gathering.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    // Upper bound on the size of the generated prompt, in bytes.
    pub max_prompt_bytes: usize,
    // Upper bound on the bytes of diagnostics included in a request.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    // Buffer edits closer together than this are coalesced into one event.
    pub buffer_change_grouping_interval: Duration,
}

/// How context for a prediction request is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    // Context is retrieved via LLM-generated searches over the project.
    Agentic(AgenticContextOptions),
    // Context is gathered from the syntax index.
    Syntax(EditPredictionContextOptions),
}

/// Options for [`ContextMode::Agentic`].
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
137
138impl ContextMode {
139 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
140 match self {
141 ContextMode::Agentic(options) => &options.excerpt,
142 ContextMode::Syntax(options) => &options.excerpt,
143 }
144 }
145}
146
/// Debug events emitted over the channel returned by [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    // Prompt sent to the model to generate search queries.
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    // Time spent gathering context before the request was sent.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    // Locally-built prompt, or the error message from building it.
    pub local_prompt: Result<String, String>,
    // Resolves with the model response (or error) and the request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    // Generated search regexes, keyed by the glob they apply to.
    pub regex_by_glob: HashMap<String, String>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
187
/// Per-project state tracked by [`Zeta`].
struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    // Recent buffer-change events, oldest first (bounded by MAX_EVENT_COUNT).
    events: VecDeque<Event>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    // Context excerpt ranges, per buffer. `None` until a retrieval completes.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    // When a context refresh was last requested; used for idle detection.
    refresh_context_timestamp: Option<Instant>,
}

/// The prediction currently offered to the user, along with the id of the
/// buffer it was requested from (which may differ from the buffer it edits).
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
204
205impl CurrentEditPrediction {
206 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
207 let Some(new_edits) = self
208 .prediction
209 .interpolate(&self.prediction.buffer.read(cx))
210 else {
211 return false;
212 };
213
214 if self.prediction.buffer != old_prediction.prediction.buffer {
215 return true;
216 }
217
218 let Some(old_edits) = old_prediction
219 .prediction
220 .interpolate(&old_prediction.prediction.buffer.read(cx))
221 else {
222 return true;
223 };
224
225 // This reduces the occurrence of UI thrash from replacing edits
226 //
227 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
228 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
229 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
230 && old_edits.len() == 1
231 && new_edits.len() == 1
232 {
233 let (old_range, old_text) = &old_edits[0];
234 let (new_range, new_text) = &new_edits[0];
235 new_range == old_range && new_text.starts_with(old_text.as_ref())
236 } else {
237 true
238 }
239 }
240}
241
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    // The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    // The prediction edits a different buffer; offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}

/// Subscription state for a buffer Zeta is tracking.
struct RegisteredBuffer {
    // Last snapshot seen; diffed against the next change to produce events.
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
262
263impl Event {
264 pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
265 match self {
266 Event::BufferChange {
267 old_snapshot,
268 new_snapshot,
269 ..
270 } => {
271 let path = new_snapshot.file().map(|f| f.full_path(cx));
272
273 let old_path = old_snapshot.file().and_then(|f| {
274 let old_path = f.full_path(cx);
275 if Some(&old_path) != path.as_ref() {
276 Some(old_path)
277 } else {
278 None
279 }
280 });
281
282 // TODO [zeta2] move to bg?
283 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
284
285 if path == old_path && diff.is_empty() {
286 None
287 } else {
288 Some(predict_edits_v3::Event::BufferChange {
289 old_path,
290 path,
291 diff,
292 //todo: Actually detect if this edit was predicted or not
293 predicted: false,
294 })
295 }
296 }
297 }
298 }
299}
300
301impl Zeta {
302 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
303 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
304 }
305
306 pub fn global(
307 client: &Arc<Client>,
308 user_store: &Entity<UserStore>,
309 cx: &mut App,
310 ) -> Entity<Self> {
311 cx.try_global::<ZetaGlobal>()
312 .map(|global| global.0.clone())
313 .unwrap_or_else(|| {
314 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
315 cx.set_global(ZetaGlobal(zeta.clone()));
316 zeta
317 })
318 }
319
    /// Creates a new [`Zeta`]. Prefer [`Zeta::global`] to obtain the shared
    /// instance.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the LLM token whenever we're told it needs refreshing.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
        }
    }
345
346 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
347 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
348 self.debug_tx = Some(debug_watch_tx);
349 debug_watch_rx
350 }
351
    /// Current prediction/context options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
355
    /// Replaces the prediction/context options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
359
360 pub fn clear_history(&mut self) {
361 for zeta_project in self.projects.values_mut() {
362 zeta_project.events.clear();
363 }
364 }
365
366 pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
367 self.projects
368 .get(&project.entity_id())
369 .map(|project| project.events.iter())
370 .into_iter()
371 .flatten()
372 }
373
374 pub fn context_for_project(
375 &self,
376 project: &Entity<Project>,
377 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
378 self.projects
379 .get(&project.entity_id())
380 .and_then(|project| {
381 Some(
382 project
383 .context
384 .as_ref()?
385 .iter()
386 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
387 )
388 })
389 .into_iter()
390 .flatten()
391 }
392
    /// Returns the user's edit prediction usage, if known.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }
396
    /// Ensures per-project state (including the syntax index) exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }
400
    /// Starts tracking `buffer` within `project`, subscribing to its edits so
    /// they are recorded in the project's edit history.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
410
411 fn get_or_init_zeta_project(
412 &mut self,
413 project: &Entity<Project>,
414 cx: &mut App,
415 ) -> &mut ZetaProject {
416 self.projects
417 .entry(project.entity_id())
418 .or_insert_with(|| ZetaProject {
419 syntax_index: cx.new(|cx| {
420 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
421 }),
422 events: VecDeque::new(),
423 registered_buffers: HashMap::default(),
424 current_prediction: None,
425 context: None,
426 refresh_context_task: None,
427 refresh_context_debounce_task: None,
428 refresh_context_timestamp: None,
429 })
430 }
431
    /// Returns the [`RegisteredBuffer`] for `buffer`, registering it on first
    /// use: subscribes to buffer edits (to record history) and to entity
    /// release (to drop the registration).
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                // Record an edit-history event for each edit,
                                // as long as the project is still alive.
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
469
    /// Records an edit-history event for `buffer` if its contents changed
    /// since the last snapshot we saw, and returns the latest snapshot.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Only push an event when the buffer's version actually advanced.
        if new_snapshot.version != registered_buffer.snapshot.version {
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                buffer_change_grouping_interval,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }
497
    /// Appends `event` to the project's history, coalescing consecutive edits
    /// to the same buffer that occur within `buffer_change_grouping_interval`,
    /// and bounding the history length to `MAX_EVENT_COUNT`.
    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            // The new event continues the previous one when it starts from
            // that event's resulting version and falls within the interval.
            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }
536
537 fn current_prediction_for_buffer(
538 &self,
539 buffer: &Entity<Buffer>,
540 project: &Entity<Project>,
541 cx: &App,
542 ) -> Option<BufferEditPrediction<'_>> {
543 let project_state = self.projects.get(&project.entity_id())?;
544
545 let CurrentEditPrediction {
546 requested_by_buffer_id,
547 prediction,
548 } = project_state.current_prediction.as_ref()?;
549
550 if prediction.targets_buffer(buffer.read(cx)) {
551 Some(BufferEditPrediction::Local { prediction })
552 } else if *requested_by_buffer_id == buffer.entity_id() {
553 Some(BufferEditPrediction::Jump { prediction })
554 } else {
555 None
556 }
557 }
558
    /// Takes the current prediction for `project` and reports its acceptance
    /// to the server (fire-and-forget).
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.into();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // Allow overriding the endpoint for local development.
            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // Records usage from headers and surfaces update-required errors.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
600
601 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
602 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
603 project_state.current_prediction.take();
604 };
605 }
606
607 pub fn refresh_prediction(
608 &mut self,
609 project: &Entity<Project>,
610 buffer: &Entity<Buffer>,
611 position: language::Anchor,
612 cx: &mut Context<Self>,
613 ) -> Task<Result<()>> {
614 let request_task = self.request_prediction(project, buffer, position, cx);
615 let buffer = buffer.clone();
616 let project = project.clone();
617
618 cx.spawn(async move |this, cx| {
619 if let Some(prediction) = request_task.await? {
620 this.update(cx, |this, cx| {
621 let project_state = this
622 .projects
623 .get_mut(&project.entity_id())
624 .context("Project not found")?;
625
626 let new_prediction = CurrentEditPrediction {
627 requested_by_buffer_id: buffer.entity_id(),
628 prediction: prediction,
629 };
630
631 if project_state
632 .current_prediction
633 .as_ref()
634 .is_none_or(|old_prediction| {
635 new_prediction.should_replace_prediction(&old_prediction, cx)
636 })
637 {
638 project_state.current_prediction = Some(new_prediction);
639 }
640 anyhow::Ok(())
641 })??;
642 }
643 Ok(())
644 })
645 }
646
    /// Requests an edit prediction for `position` in `active_buffer`.
    ///
    /// Builds a request from the cursor excerpt, previously retrieved context
    /// excerpts, recent edit events, and nearby diagnostics; sends it to the
    /// model; then parses the returned unified diff into concrete buffer
    /// edits. Resolves to `None` when no excerpt/context could be gathered or
    /// the model produced no usable output.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Serialize the recent edit history into request events.
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        // Absolute path of the buffer's parent directory, when available.
        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        // Previously retrieved context, as (buffer, snapshot, path, ranges).
        // Buffers without a file are skipped.
        let mut included_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&HashMap::default())
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        // Select the excerpt around the cursor.
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the included files,
                        // moving the active buffer's entry to the end.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            // Insert while keeping the ranges sorted.
                            let range_ix = ranges
                                .binary_search_by(|probe| {
                                    probe
                                        .start
                                        .cmp(&excerpt_anchor_range.start, buffer)
                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
                                })
                                .unwrap_or_else(|ix| ix);

                            ranges.insert(range_ix, excerpt_anchor_range);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot,
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges into merged line-based excerpts.
                        let included_files = included_files
                            .iter()
                            .map(|(_, buffer, path, ranges)| {
                                let excerpts = merge_excerpts(
                                    &buffer,
                                    ranges.iter().map(|range| {
                                        let point_range = range.to_point(&buffer);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    }),
                                );
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(buffer.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                // When debugging, notify listeners about the request and hand
                // them a oneshot channel for the eventual response.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Escape hatch for local debugging: never hit the server.
                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let request = open_ai::Request {
                    model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("yqvev8r3".to_string()),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response =
                    Self::send_raw_llm_request(client, llm_token, app_version, request).await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(Uuid::from_str(&res.id)?);
                let Some(output_text) = text_from_response(res) else {
                    return Ok((None, usage))
                };

                // Parse the model's unified diff against the included files.
                let (edited_buffer_snapshot, edits) =
                    crate::udiff::parse_diff(&output_text, |path| {
                        included_files
                            .iter()
                            .find_map(|(_, buffer, probe_path, ranges)| {
                                if probe_path.as_ref() == path {
                                    Some((buffer, ranges.as_slice()))
                                } else {
                                    None
                                }
                            })
                    })
                    .await?;

                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers, even though we just found the snapshot")?;

                anyhow::Ok((Some((request_id, edited_buffer, edited_buffer_snapshot.clone(), edits)), usage))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }
973
974 async fn send_raw_llm_request(
975 client: Arc<Client>,
976 llm_token: LlmApiToken,
977 app_version: SemanticVersion,
978 request: open_ai::Request,
979 ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
980 let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
981 http_client::Url::parse(&predict_edits_url)?
982 } else {
983 client
984 .http_client()
985 .build_zed_llm_url("/predict_edits/raw", &[])?
986 };
987
988 Self::send_api_request(
989 |builder| {
990 let req = builder
991 .uri(url.as_ref())
992 .body(serde_json::to_string(&request)?.into());
993 Ok(req?)
994 },
995 client,
996 llm_token,
997 app_version,
998 )
999 .await
1000 }
1001
    /// Common handling for API responses: records usage reported alongside
    /// the payload, and on [`ZedUpdateRequiredError`] flags that an update is
    /// required and shows a notification with an update link.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }
1045
    /// Sends an authenticated POST request built by `build` and deserializes
    /// the JSON response.
    ///
    /// Refreshes the LLM token and retries once when the server reports an
    /// expired token. Fails with [`ZedUpdateRequiredError`] when the server
    /// requires a newer Zed version.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // The server may demand a minimum client version via a header.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Token expired: refresh and retry exactly once.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1109
    /// After this long without edits, the next edit triggers an immediate
    /// context refresh.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied to context refreshes while the user keeps editing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1112
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
    //
    // Only applies in `ContextMode::Agentic`; no-op for unregistered projects.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // "Idle" means no refresh was recorded recently (or ever).
        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Replacing the debounce task cancels any pending (not-yet-fired) refresh.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    // Resuming after idle: refresh right away.
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    // Mid-editing: wait for a pause before refreshing.
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Store the handle so the refresh runs to completion and
                    // `refresh_context_task.take()` can observe it finishing.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
1160
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    //
    // Pipeline: build a retrieval prompt around the cursor excerpt → ask the
    // LLM (via a tool call) for search queries → run those regex searches in
    // the project → store the resulting excerpts as the project's context.
    // Returns a ready task (Ok) when the project is unregistered, the context
    // mode isn't agentic, or no excerpt can be selected.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        // Untitled buffers get a placeholder path for the prompt.
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the search-planning prompt from the cursor excerpt plus the
        // project's recent edit events.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        // Debug notifications are best-effort; failures to send are ignored.
        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        let (tool_schema, tool_description) = &*cloud_zeta2_prompt::retrieval_prompt::TOOL_SCHEMA;

        // Single user message; the model is expected to respond with calls to
        // the retrieval search tool.
        let request = open_ai::Request {
            model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("2327jz9q".to_string()),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description.clone()),
                    parameters: Some(tool_schema.clone()),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response =
                Self::send_raw_llm_request(client, llm_token, app_version, request).await;
            let mut response = Self::handle_api_response(&this, response, cx)?;

            log::trace!("Got search planning response");

            // Only the last choice is used.
            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collapse all requested queries into one alternation regex per
            // glob, so each glob is searched once.
            let mut regex_by_glob: HashMap<String, String> = HashMap::default();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)?;
                for query in input.queries {
                    let regex = regex_by_glob.entry(query.glob).or_default();
                    if !regex.is_empty() {
                        regex.push('|');
                    }
                    regex.push_str(&query.regex);
                }
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            regex_by_glob: regex_by_glob.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {regex_by_glob:#?}");

            let related_excerpts_result =
                retrieval_search::run_retrieval_searches(project.clone(), regex_by_glob, cx).await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                // Clear the handle so a subsequent refresh may be spawned.
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                // Only overwrite stored context on success.
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
1350
1351 fn gather_nearby_diagnostics(
1352 cursor_offset: usize,
1353 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1354 snapshot: &BufferSnapshot,
1355 max_diagnostics_bytes: usize,
1356 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1357 // TODO: Could make this more efficient
1358 let mut diagnostic_groups = Vec::new();
1359 for (language_server_id, diagnostics) in diagnostic_sets {
1360 let mut groups = Vec::new();
1361 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1362 diagnostic_groups.extend(
1363 groups
1364 .into_iter()
1365 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1366 );
1367 }
1368
1369 // sort by proximity to cursor
1370 diagnostic_groups.sort_by_key(|group| {
1371 let range = &group.entries[group.primary_ix].range;
1372 if range.start >= cursor_offset {
1373 range.start - cursor_offset
1374 } else if cursor_offset >= range.end {
1375 cursor_offset - range.end
1376 } else {
1377 (cursor_offset - range.start).min(range.end - cursor_offset)
1378 }
1379 });
1380
1381 let mut results = Vec::new();
1382 let mut diagnostic_groups_truncated = false;
1383 let mut diagnostics_byte_count = 0;
1384 for group in diagnostic_groups {
1385 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1386 diagnostics_byte_count += raw_value.get().len();
1387 if diagnostics_byte_count > max_diagnostics_bytes {
1388 diagnostic_groups_truncated = true;
1389 break;
1390 }
1391 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1392 }
1393
1394 (results, diagnostic_groups_truncated)
1395 }
1396
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds a cloud `PredictEditsRequest` for the zeta CLI by gathering
    /// syntax-based context around `position` on a background thread.
    ///
    /// Errors if the buffer has no file path or excerpt selection fails.
    /// Panics (TODO) when the configured context mode is agentic, which the
    /// CLI does not support yet.
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        // Clone of the syntax index state, if this project is registered.
        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, when the buffer is
        // backed by a worktree file with a parent.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            // Hold the index lock for the duration of context gathering.
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
1473
1474 pub fn wait_for_initial_indexing(
1475 &mut self,
1476 project: &Entity<Project>,
1477 cx: &mut App,
1478 ) -> Task<Result<()>> {
1479 let zeta_project = self.get_or_init_zeta_project(project, cx);
1480 zeta_project
1481 .syntax_index
1482 .read(cx)
1483 .wait_for_initial_file_indexing(cx)
1484 }
1485}
1486
1487pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1488 let choice = res.choices.pop()?;
1489 let output_text = match choice.message {
1490 open_ai::RequestMessage::Assistant {
1491 content: Some(open_ai::MessageContent::Plain(content)),
1492 ..
1493 } => content,
1494 open_ai::RequestMessage::Assistant {
1495 content: Some(open_ai::MessageContent::Multipart(mut content)),
1496 ..
1497 } => {
1498 if content.is_empty() {
1499 log::error!("No output from Baseten completion response");
1500 return None;
1501 }
1502
1503 match content.remove(0) {
1504 open_ai::MessagePart::Text { text } => text,
1505 open_ai::MessagePart::Image { .. } => {
1506 log::error!("Expected text, got an image");
1507 return None;
1508 }
1509 }
1510 }
1511 _ => {
1512 log::error!("Invalid response message: {:?}", choice.message);
1513 return None;
1514 }
1515 };
1516 Some(output_text)
1517}
1518
/// Error raised when the prediction service reports (via response headers)
/// that this Zed build is older than the minimum it supports.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest Zed version the server will accept.
    minimum_version: SemanticVersion,
}
1526
1527fn make_syntax_context_cloud_request(
1528 excerpt_path: Arc<Path>,
1529 context: EditPredictionContext,
1530 events: Vec<predict_edits_v3::Event>,
1531 can_collect_data: bool,
1532 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1533 diagnostic_groups_truncated: bool,
1534 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1535 debug_info: bool,
1536 worktrees: &Vec<worktree::Snapshot>,
1537 index_state: Option<&SyntaxIndexState>,
1538 prompt_max_bytes: Option<usize>,
1539 prompt_format: PromptFormat,
1540) -> predict_edits_v3::PredictEditsRequest {
1541 let mut signatures = Vec::new();
1542 let mut declaration_to_signature_index = HashMap::default();
1543 let mut referenced_declarations = Vec::new();
1544
1545 for snippet in context.declarations {
1546 let project_entry_id = snippet.declaration.project_entry_id();
1547 let Some(path) = worktrees.iter().find_map(|worktree| {
1548 worktree.entry_for_id(project_entry_id).map(|entry| {
1549 let mut full_path = RelPathBuf::new();
1550 full_path.push(worktree.root_name());
1551 full_path.push(&entry.path);
1552 full_path
1553 })
1554 }) else {
1555 continue;
1556 };
1557
1558 let parent_index = index_state.and_then(|index_state| {
1559 snippet.declaration.parent().and_then(|parent| {
1560 add_signature(
1561 parent,
1562 &mut declaration_to_signature_index,
1563 &mut signatures,
1564 index_state,
1565 )
1566 })
1567 });
1568
1569 let (text, text_is_truncated) = snippet.declaration.item_text();
1570 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1571 path: path.as_std_path().into(),
1572 text: text.into(),
1573 range: snippet.declaration.item_line_range(),
1574 text_is_truncated,
1575 signature_range: snippet.declaration.signature_range_in_item_text(),
1576 parent_index,
1577 signature_score: snippet.score(DeclarationStyle::Signature),
1578 declaration_score: snippet.score(DeclarationStyle::Declaration),
1579 score_components: snippet.components,
1580 });
1581 }
1582
1583 let excerpt_parent = index_state.and_then(|index_state| {
1584 context
1585 .excerpt
1586 .parent_declarations
1587 .last()
1588 .and_then(|(parent, _)| {
1589 add_signature(
1590 *parent,
1591 &mut declaration_to_signature_index,
1592 &mut signatures,
1593 index_state,
1594 )
1595 })
1596 });
1597
1598 predict_edits_v3::PredictEditsRequest {
1599 excerpt_path,
1600 excerpt: context.excerpt_text.body,
1601 excerpt_line_range: context.excerpt.line_range,
1602 excerpt_range: context.excerpt.range,
1603 cursor_point: predict_edits_v3::Point {
1604 line: predict_edits_v3::Line(context.cursor_point.row),
1605 column: context.cursor_point.column,
1606 },
1607 referenced_declarations,
1608 included_files: vec![],
1609 signatures,
1610 excerpt_parent,
1611 events,
1612 can_collect_data,
1613 diagnostic_groups,
1614 diagnostic_groups_truncated,
1615 git_info,
1616 debug_info,
1617 prompt_max_bytes,
1618 prompt_format,
1619 }
1620}
1621
1622fn add_signature(
1623 declaration_id: DeclarationId,
1624 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1625 signatures: &mut Vec<Signature>,
1626 index: &SyntaxIndexState,
1627) -> Option<usize> {
1628 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1629 return Some(*signature_index);
1630 }
1631 let Some(parent_declaration) = index.declaration(declaration_id) else {
1632 log::error!("bug: missing parent declaration");
1633 return None;
1634 };
1635 let parent_index = parent_declaration.parent().and_then(|parent| {
1636 add_signature(parent, declaration_to_signature_index, signatures, index)
1637 });
1638 let (text, text_is_truncated) = parent_declaration.signature_text();
1639 let signature_index = signatures.len();
1640 signatures.push(Signature {
1641 text: text.into(),
1642 text_is_truncated,
1643 parent_index,
1644 range: parent_declaration.signature_line_range(),
1645 });
1646 declaration_to_signature_index.insert(declaration_id, signature_index);
1647 Some(signature_index)
1648}
1649
1650#[cfg(test)]
1651mod tests {
1652 use std::{path::Path, sync::Arc};
1653
1654 use client::UserStore;
1655 use clock::FakeSystemClock;
1656 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
1657 use futures::{
1658 AsyncReadExt, StreamExt,
1659 channel::{mpsc, oneshot},
1660 };
1661 use gpui::{
1662 Entity, TestAppContext,
1663 http_client::{FakeHttpClient, Response},
1664 prelude::*,
1665 };
1666 use indoc::indoc;
1667 use language::OffsetRangeExt as _;
1668 use open_ai::Usage;
1669 use pretty_assertions::{assert_eq, assert_matches};
1670 use project::{FakeFs, Project};
1671 use serde_json::json;
1672 use settings::SettingsStore;
1673 use util::path;
1674 use uuid::Uuid;
1675
1676 use crate::{BufferEditPrediction, Zeta};
1677
    /// End-to-end check of prediction state transitions: a local prediction
    /// for the current buffer, a context refresh driven by a fake tool call,
    /// then a "jump" prediction targeting a different file which becomes
    /// local once that file's buffer is opened.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // Respond with a diff editing the buffer the cursor is in.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();
        prediction_task.await.unwrap();

        // An edit to the current buffer surfaces as a Local prediction.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Fake a retrieval-planning response: one search tool call covering
        // all of 2.txt, which pulls it into the project context.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            regex: ".".to_string(),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        prediction_task.await.unwrap();
        // Viewed from buffer1, an edit targeting 2.txt is a Jump prediction.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // The same prediction is Local when viewed from the target buffer.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
1821
    /// Happy-path request: a single prediction round-trip where the model's
    /// diff response is applied as one edit at the cursor position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Cursor at the end of "How" (row 1, column 3).
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The diff should reduce to a single insertion right at the cursor.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
1885
    /// Verifies that edits made to a registered buffer are reported to the
    /// model as a diff of recent events inside the request prompt.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Registration is what makes Zeta track this buffer's edit events.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Insert "How" on the empty second line (byte offset 7).
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The prompt must contain the recorded edit as a unified diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                 Hello!
                -
                +How
                 Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
1959
1960 // Skipped until we start including diagnostics in prompt
1961 // #[gpui::test]
1962 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
1963 // let (zeta, mut req_rx) = init_test(cx);
1964 // let fs = FakeFs::new(cx.executor());
1965 // fs.insert_tree(
1966 // "/root",
1967 // json!({
1968 // "foo.md": "Hello!\nBye"
1969 // }),
1970 // )
1971 // .await;
1972 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1973
1974 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1975 // let diagnostic = lsp::Diagnostic {
1976 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1977 // severity: Some(lsp::DiagnosticSeverity::ERROR),
1978 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1979 // ..Default::default()
1980 // };
1981
1982 // project.update(cx, |project, cx| {
1983 // project.lsp_store().update(cx, |lsp_store, cx| {
1984 // // Create some diagnostics
1985 // lsp_store
1986 // .update_diagnostics(
1987 // LanguageServerId(0),
1988 // lsp::PublishDiagnosticsParams {
1989 // uri: path_to_buffer_uri.clone(),
1990 // diagnostics: vec![diagnostic],
1991 // version: None,
1992 // },
1993 // None,
1994 // language::DiagnosticSourceKind::Pushed,
1995 // &[],
1996 // cx,
1997 // )
1998 // .unwrap();
1999 // });
2000 // });
2001
2002 // let buffer = project
2003 // .update(cx, |project, cx| {
2004 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2005 // project.open_buffer(path, cx)
2006 // })
2007 // .await
2008 // .unwrap();
2009
2010 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2011 // let position = snapshot.anchor_before(language::Point::new(0, 0));
2012
2013 // let _prediction_task = zeta.update(cx, |zeta, cx| {
2014 // zeta.request_prediction(&project, &buffer, position, cx)
2015 // });
2016
2017 // let (request, _respond_tx) = req_rx.next().await.unwrap();
2018
2019 // assert_eq!(request.diagnostic_groups.len(), 1);
2020 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
2021 // .unwrap();
2022 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
2023 // assert_eq!(
2024 // value,
2025 // json!({
2026 // "entries": [{
2027 // "range": {
2028 // "start": 8,
2029 // "end": 10
2030 // },
2031 // "diagnostic": {
2032 // "source": null,
2033 // "code": null,
2034 // "code_description": null,
2035 // "severity": 1,
2036 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
2037 // "markdown": null,
2038 // "group_id": 0,
2039 // "is_primary": true,
2040 // "is_disk_based": false,
2041 // "is_unnecessary": false,
2042 // "source_kind": "Pushed",
2043 // "data": null,
2044 // "underline": true
2045 // }
2046 // }],
2047 // "primary_ix": 0
2048 // })
2049 // );
2050 // }
2051
2052 fn model_response(text: &str) -> open_ai::Response {
2053 open_ai::Response {
2054 id: Uuid::new_v4().to_string(),
2055 object: "response".into(),
2056 created: 0,
2057 model: "model".into(),
2058 choices: vec![open_ai::Choice {
2059 index: 0,
2060 message: open_ai::RequestMessage::Assistant {
2061 content: Some(open_ai::MessageContent::Plain(text.to_string())),
2062 tool_calls: vec![],
2063 },
2064 finish_reason: None,
2065 }],
2066 usage: Usage {
2067 prompt_tokens: 0,
2068 completion_tokens: 0,
2069 total_tokens: 0,
2070 },
2071 }
2072 }
2073
2074 fn prompt_from_request(request: &open_ai::Request) -> &str {
2075 assert_eq!(request.messages.len(), 1);
2076 let open_ai::RequestMessage::User {
2077 content: open_ai::MessageContent::Plain(content),
2078 ..
2079 } = &request.messages[0]
2080 else {
2081 panic!(
2082 "Request does not have single user message of type Plain. {:#?}",
2083 request
2084 );
2085 };
2086 content
2087 }
2088
    /// Sets up a Zeta instance backed by a fake HTTP client.
    ///
    /// Returns the Zeta entity plus a channel that yields each intercepted
    /// `/predict_edits/raw` request together with a oneshot sender the test
    /// uses to supply the fake model response.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh endpoint: always hand back a
                            // static test token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction endpoint: forward the parsed request
                            // to the test and wait for its scripted response.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
2143}