zeta2.rs

use anyhow::{Context as _, Result, anyhow};
use chrono::TimeDelta;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
use cloud_llm_client::{
    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
    ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
use collections::HashMap;
use edit_prediction_context::{
    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
    SyntaxIndex, SyntaxIndexState,
};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::AsyncReadExt as _;
use futures::channel::{mpsc, oneshot};
use gpui::http_client::{AsyncBody, Method};
use gpui::{
    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
    http_client, prelude::*,
};
use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use open_ai::FunctionDefinition;
use project::Project;
use release_channel::AppVersion;
use serde::de::DeserializeOwned;
use std::collections::{VecDeque, hash_map};
use uuid::Uuid;

use std::ops::Range;
use std::path::Path;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use thiserror::Error;
use util::rel_path::RelPathBuf;
use util::{LogErrorFuture, TryFutureExt};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};

pub mod merge_excerpts;
mod prediction;
mod provider;
pub mod retrieval_search;
pub mod udiff;

use crate::merge_excerpts::merge_excerpts;
use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
pub use provider::ZetaEditPredictionProvider;

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};

static MODEL_ID: LazyLock<String> =
    LazyLock::new(|| std::env::var("ZED_ZETA2_MODEL").unwrap_or("yqvev8r3".to_string()));

pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        false
    }
}

#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}

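/// Entry point for Zeta 2 edit predictions: tracks per-project state (syntax index,
/// recent buffer events, retrieved context) and talks to the cloud LLM endpoint.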
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    update_required: bool,
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}

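/// Options controlling context gathering, prompt size, and prompt format for predictions.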
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}

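/// How context is gathered for a prediction request: via model-planned searches
/// (agentic) or purely from the syntax index.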
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    Agentic(AgenticContextOptions),
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}

impl ContextMode {
    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
        match self {
            ContextMode::Agentic(options) => &options.excerpt,
            ContextMode::Syntax(options) => &options.excerpt,
        }
    }
}

#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    pub local_prompt: Result<String, String>,
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    events: VecDeque<Event>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}

#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}

impl CurrentEditPrediction {
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}

struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}

impl Event {
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2] move to bg?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                if path == old_path && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        // TODO: Actually detect whether this edit was predicted or not
                        predicted: false,
                    })
                }
            }
        }
    }
}

impl Zeta {
    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
    }

    pub fn global(
        client: &Arc<Client>,
        user_store: &Entity<UserStore>,
        cx: &mut App,
    ) -> Entity<Self> {
        cx.try_global::<ZetaGlobal>()
            .map(|global| global.0.clone())
            .unwrap_or_else(|| {
                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
                cx.set_global(ZetaGlobal(zeta.clone()));
                zeta
            })
    }

    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
        }
    }

    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    pub fn clear_history(&mut self) {
        for zeta_project in self.projects.values_mut() {
            zeta_project.events.clear();
        }
    }

    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
        self.projects
            .get(&project.entity_id())
            .map(|project| project.events.iter())
            .into_iter()
            .flatten()
    }

    pub fn context_for_project(
        &self,
        project: &Entity<Project>,
    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
        self.projects
            .get(&project.entity_id())
            .and_then(|project| {
                Some(
                    project
                        .context
                        .as_ref()?
                        .iter()
                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
                )
            })
            .into_iter()
            .flatten()
    }

    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }

    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }

    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }

    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: cx.new(|cx| {
                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                }),
                events: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
            })
    }

    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }

    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version != registered_buffer.snapshot.version {
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                buffer_change_grouping_interval,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }

    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }

    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by_buffer_id,
            prediction,
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else if *requested_by_buffer_id == buffer.entity_id() {
            Some(BufferEditPrediction::Jump { prediction })
        } else {
            None
        }
    }

    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.into();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }

    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
            project_state.current_prediction.take();
        };
    }

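    /// Requests a fresh prediction for the given position and installs it as the
    /// project's current prediction unless the existing one should be kept.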
    pub fn refresh_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let request_task = self.request_prediction(project, buffer, position, cx);
        let buffer = buffer.clone();
        let project = project.clone();

        cx.spawn(async move |this, cx| {
            if let Some(prediction) = request_task.await? {
                this.update(cx, |this, cx| {
                    let project_state = this
                        .projects
                        .get_mut(&project.entity_id())
                        .context("Project not found")?;

                    let new_prediction = CurrentEditPrediction {
                        requested_by_buffer_id: buffer.entity_id(),
                        prediction,
                    };

                    if project_state
                        .current_prediction
                        .as_ref()
                        .is_none_or(|old_prediction| {
                            new_prediction.should_replace_prediction(&old_prediction, cx)
                        })
                    {
                        project_state.current_prediction = Some(new_prediction);
                    }
                    anyhow::Ok(())
                })??;
            }
            Ok(())
        })
    }

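    /// Builds an edit prediction request for the cursor position, sends it to the
    /// model, and parses the returned diff into concrete buffer edits.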
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        let mut included_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&HashMap::default())
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            let range_ix = ranges
                                .binary_search_by(|probe| {
                                    probe
                                        .start
                                        .cmp(&excerpt_anchor_range.start, buffer)
                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
                                })
                                .unwrap_or_else(|ix| ix);

                            ranges.insert(range_ix, excerpt_anchor_range);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot,
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        let included_files = included_files
                            .iter()
                            .map(|(_, buffer, path, ranges)| {
                                let excerpts = merge_excerpts(
                                    &buffer,
                                    ranges.iter().map(|range| {
                                        let point_range = range.to_point(&buffer);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    }),
                                );
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(buffer.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let request = open_ai::Request {
                    model: MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response =
                    Self::send_raw_llm_request(client, llm_token, app_version, request).await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(Uuid::from_str(&res.id)?);
                let Some(output_text) = text_from_response(res) else {
                    return Ok((None, usage))
                };

                let (edited_buffer_snapshot, edits) =
                    crate::udiff::parse_diff(&output_text, |path| {
                        included_files
                            .iter()
                            .find_map(|(_, buffer, probe_path, ranges)| {
                                if probe_path.as_ref() == path {
                                    Some((buffer, ranges.as_slice()))
                                } else {
                                    None
                                }
                            })
                    })
                    .await?;

                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers, even though we just found the snapshot")?;

                anyhow::Ok((Some((request_id, edited_buffer, edited_buffer_snapshot.clone(), edits)), usage))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }

    async fn send_raw_llm_request(
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        request: open_ai::Request,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await
    }

    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }

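    /// Sends an authenticated POST to the LLM service, retrying once with a refreshed
    /// token when the server reports that the current LLM token has expired.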
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }

    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);

    // Refresh the related excerpts when the user first begins editing after
    // an idle period, and again after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }

    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response =
                Self::send_raw_llm_request(client, llm_token, app_version, request).await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result =
                retrieval_search::run_retrieval_searches(project.clone(), queries, cx).await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }

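    /// Collects diagnostic groups ordered by distance from the cursor, stopping once
    /// their serialized size would exceed the diagnostics byte budget.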
1361    fn gather_nearby_diagnostics(
1362        cursor_offset: usize,
1363        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1364        snapshot: &BufferSnapshot,
1365        max_diagnostics_bytes: usize,
1366    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1367        // TODO: Could make this more efficient
1368        let mut diagnostic_groups = Vec::new();
1369        for (language_server_id, diagnostics) in diagnostic_sets {
1370            let mut groups = Vec::new();
1371            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1372            diagnostic_groups.extend(
1373                groups
1374                    .into_iter()
1375                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1376            );
1377        }
1378
1379        // sort by proximity to cursor
1380        diagnostic_groups.sort_by_key(|group| {
1381            let range = &group.entries[group.primary_ix].range;
1382            if range.start >= cursor_offset {
1383                range.start - cursor_offset
1384            } else if cursor_offset >= range.end {
1385                cursor_offset - range.end
1386            } else {
1387                (cursor_offset - range.start).min(range.end - cursor_offset)
1388            }
1389        });
1390
1391        let mut results = Vec::new();
1392        let mut diagnostic_groups_truncated = false;
1393        let mut diagnostics_byte_count = 0;
1394        for group in diagnostic_groups {
1395            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1396            diagnostics_byte_count += raw_value.get().len();
1397            if diagnostics_byte_count > max_diagnostics_bytes {
1398                diagnostic_groups_truncated = true;
1399                break;
1400            }
1401            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1402        }
1403
1404        (results, diagnostic_groups_truncated)
1405    }
1406
1407    // TODO: Dedupe with similar code in request_prediction?
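        /// Builds a cloud `PredictEditsRequest` for the given buffer position, for use by the
        /// zeta CLI. Only `ContextMode::Syntax` is supported; events, diagnostics, and git info
        /// are currently sent as empty placeholders.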
1408    pub fn cloud_request_for_zeta_cli(
1409        &mut self,
1410        project: &Entity<Project>,
1411        buffer: &Entity<Buffer>,
1412        position: language::Anchor,
1413        cx: &mut Context<Self>,
1414    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1415        let project_state = self.projects.get(&project.entity_id());
1416
1417        let index_state = project_state.map(|state| {
1418            state
1419                .syntax_index
1420                .read_with(cx, |index, _cx| index.state().clone())
1421        });
1422        let options = self.options.clone();
1423        let snapshot = buffer.read(cx).snapshot();
1424        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1425            return Task::ready(Err(anyhow!("No file path for excerpt")));
1426        };
1427        let worktree_snapshots = project
1428            .read(cx)
1429            .worktrees(cx)
1430            .map(|worktree| worktree.read(cx).snapshot())
1431            .collect::<Vec<_>>();
1432
1433        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1434            let mut path = f.worktree.read(cx).absolutize(&f.path);
1435            if path.pop() { Some(path) } else { None }
1436        });
1437
1438        cx.background_spawn(async move {
1439            let index_state = if let Some(index_state) = index_state {
1440                Some(index_state.lock_owned().await)
1441            } else {
1442                None
1443            };
1444
1445            let cursor_point = position.to_point(&snapshot);
1446
1447            let debug_info = true;
1448            EditPredictionContext::gather_context(
1449                cursor_point,
1450                &snapshot,
1451                parent_abs_path.as_deref(),
1452                match &options.context {
1453                    ContextMode::Agentic(_) => {
1454                        // TODO: support agentic context gathering in the zeta CLI
1455                        panic!("Agentic context mode is not supported in the zeta CLI yet");
1456                    }
1457                    ContextMode::Syntax(edit_prediction_context_options) => {
1458                        edit_prediction_context_options
1459                    }
1460                },
1461                index_state.as_deref(),
1462            )
1463            .context("Failed to select excerpt")
1464            .map(|context| {
1465                make_syntax_context_cloud_request(
1466                    excerpt_path.into(),
1467                    context,
1468                    // TODO: pass real events, diagnostics, and git info instead of these placeholders
1469                    Vec::new(),
1470                    false,
1471                    Vec::new(),
1472                    false,
1473                    None,
1474                    debug_info,
1475                    &worktree_snapshots,
1476                    index_state.as_deref(),
1477                    Some(options.max_prompt_bytes),
1478                    options.prompt_format,
1479                )
1480            })
1481        })
1482    }
1483
1484    pub fn wait_for_initial_indexing(
1485        &mut self,
1486        project: &Entity<Project>,
1487        cx: &mut App,
1488    ) -> Task<Result<()>> {
1489        let zeta_project = self.get_or_init_zeta_project(project, cx);
1490        zeta_project
1491            .syntax_index
1492            .read(cx)
1493            .wait_for_initial_file_indexing(cx)
1494    }
1495}
1496
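    /// Extracts the text content from the last choice of a completion response.
    /// Returns `None` and logs an error if the message is not an assistant message with
    /// plain or non-empty multipart text content.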
1497pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1498    let choice = res.choices.pop()?;
1499    let output_text = match choice.message {
1500        open_ai::RequestMessage::Assistant {
1501            content: Some(open_ai::MessageContent::Plain(content)),
1502            ..
1503        } => content,
1504        open_ai::RequestMessage::Assistant {
1505            content: Some(open_ai::MessageContent::Multipart(mut content)),
1506            ..
1507        } => {
1508            if content.is_empty() {
1509                log::error!("No output from Baseten completion response");
1510                return None;
1511            }
1512
1513            match content.remove(0) {
1514                open_ai::MessagePart::Text { text } => text,
1515                open_ai::MessagePart::Image { .. } => {
1516                    log::error!("Expected text, got an image");
1517                    return None;
1518                }
1519            }
1520        }
1521        _ => {
1522            log::error!("Invalid response message: {:?}", choice.message);
1523            return None;
1524        }
1525    };
1526    Some(output_text)
1527}
1528
1529#[derive(Error, Debug)]
1530#[error(
1531    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1532)]
1533pub struct ZedUpdateRequiredError {
1534    minimum_version: SemanticVersion,
1535}
1536
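    /// Assembles a `predict_edits_v3::PredictEditsRequest` from syntax-based context:
    /// referenced declarations are resolved to project paths (worktree root name plus
    /// relative path), their parent signatures are deduplicated into a shared signature
    /// table, and events, diagnostics, and git info are passed through as provided.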
1537fn make_syntax_context_cloud_request(
1538    excerpt_path: Arc<Path>,
1539    context: EditPredictionContext,
1540    events: Vec<predict_edits_v3::Event>,
1541    can_collect_data: bool,
1542    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1543    diagnostic_groups_truncated: bool,
1544    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1545    debug_info: bool,
1546    worktrees: &[worktree::Snapshot],
1547    index_state: Option<&SyntaxIndexState>,
1548    prompt_max_bytes: Option<usize>,
1549    prompt_format: PromptFormat,
1550) -> predict_edits_v3::PredictEditsRequest {
1551    let mut signatures = Vec::new();
1552    let mut declaration_to_signature_index = HashMap::default();
1553    let mut referenced_declarations = Vec::new();
1554
1555    for snippet in context.declarations {
1556        let project_entry_id = snippet.declaration.project_entry_id();
1557        let Some(path) = worktrees.iter().find_map(|worktree| {
1558            worktree.entry_for_id(project_entry_id).map(|entry| {
1559                let mut full_path = RelPathBuf::new();
1560                full_path.push(worktree.root_name());
1561                full_path.push(&entry.path);
1562                full_path
1563            })
1564        }) else {
1565            continue;
1566        };
1567
1568        let parent_index = index_state.and_then(|index_state| {
1569            snippet.declaration.parent().and_then(|parent| {
1570                add_signature(
1571                    parent,
1572                    &mut declaration_to_signature_index,
1573                    &mut signatures,
1574                    index_state,
1575                )
1576            })
1577        });
1578
1579        let (text, text_is_truncated) = snippet.declaration.item_text();
1580        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1581            path: path.as_std_path().into(),
1582            text: text.into(),
1583            range: snippet.declaration.item_line_range(),
1584            text_is_truncated,
1585            signature_range: snippet.declaration.signature_range_in_item_text(),
1586            parent_index,
1587            signature_score: snippet.score(DeclarationStyle::Signature),
1588            declaration_score: snippet.score(DeclarationStyle::Declaration),
1589            score_components: snippet.components,
1590        });
1591    }
1592
1593    let excerpt_parent = index_state.and_then(|index_state| {
1594        context
1595            .excerpt
1596            .parent_declarations
1597            .last()
1598            .and_then(|(parent, _)| {
1599                add_signature(
1600                    *parent,
1601                    &mut declaration_to_signature_index,
1602                    &mut signatures,
1603                    index_state,
1604                )
1605            })
1606    });
1607
1608    predict_edits_v3::PredictEditsRequest {
1609        excerpt_path,
1610        excerpt: context.excerpt_text.body,
1611        excerpt_line_range: context.excerpt.line_range,
1612        excerpt_range: context.excerpt.range,
1613        cursor_point: predict_edits_v3::Point {
1614            line: predict_edits_v3::Line(context.cursor_point.row),
1615            column: context.cursor_point.column,
1616        },
1617        referenced_declarations,
1618        included_files: vec![],
1619        signatures,
1620        excerpt_parent,
1621        events,
1622        can_collect_data,
1623        diagnostic_groups,
1624        diagnostic_groups_truncated,
1625        git_info,
1626        debug_info,
1627        prompt_max_bytes,
1628        prompt_format,
1629    }
1630}
1631
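    /// Appends the signature of `declaration_id` (and, recursively, its ancestors) to
    /// `signatures`, memoizing indices in `declaration_to_signature_index`. Returns the
    /// signature's index, or `None` if the declaration is missing from the index.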
1632fn add_signature(
1633    declaration_id: DeclarationId,
1634    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1635    signatures: &mut Vec<Signature>,
1636    index: &SyntaxIndexState,
1637) -> Option<usize> {
1638    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1639        return Some(*signature_index);
1640    }
1641    let Some(parent_declaration) = index.declaration(declaration_id) else {
1642        log::error!("bug: missing parent declaration");
1643        return None;
1644    };
1645    let parent_index = parent_declaration.parent().and_then(|parent| {
1646        add_signature(parent, declaration_to_signature_index, signatures, index)
1647    });
1648    let (text, text_is_truncated) = parent_declaration.signature_text();
1649    let signature_index = signatures.len();
1650    signatures.push(Signature {
1651        text: text.into(),
1652        text_is_truncated,
1653        parent_index,
1654        range: parent_declaration.signature_line_range(),
1655    });
1656    declaration_to_signature_index.insert(declaration_id, signature_index);
1657    Some(signature_index)
1658}
1659
1660#[cfg(test)]
1661mod tests {
1662    use std::{path::Path, sync::Arc};
1663
1664    use client::UserStore;
1665    use clock::FakeSystemClock;
1666    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
1667    use futures::{
1668        AsyncReadExt, StreamExt,
1669        channel::{mpsc, oneshot},
1670    };
1671    use gpui::{
1672        Entity, TestAppContext,
1673        http_client::{FakeHttpClient, Response},
1674        prelude::*,
1675    };
1676    use indoc::indoc;
1677    use language::OffsetRangeExt as _;
1678    use open_ai::Usage;
1679    use pretty_assertions::{assert_eq, assert_matches};
1680    use project::{FakeFs, Project};
1681    use serde_json::json;
1682    use settings::SettingsStore;
1683    use util::path;
1684    use uuid::Uuid;
1685
1686    use crate::{BufferEditPrediction, Zeta};
1687
1688    #[gpui::test]
1689    async fn test_current_state(cx: &mut TestAppContext) {
1690        let (zeta, mut req_rx) = init_test(cx);
1691        let fs = FakeFs::new(cx.executor());
1692        fs.insert_tree(
1693            "/root",
1694            json!({
1695                "1.txt": "Hello!\nHow\nBye\n",
1696                "2.txt": "Hola!\nComo\nAdios\n"
1697            }),
1698        )
1699        .await;
1700        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1701
1702        zeta.update(cx, |zeta, cx| {
1703            zeta.register_project(&project, cx);
1704        });
1705
1706        let buffer1 = project
1707            .update(cx, |project, cx| {
1708                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
1709                project.open_buffer(path, cx)
1710            })
1711            .await
1712            .unwrap();
1713        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
1714        let position = snapshot1.anchor_before(language::Point::new(1, 3));
1715
1716        // Prediction for current file
1717
1718        let prediction_task = zeta.update(cx, |zeta, cx| {
1719            zeta.refresh_prediction(&project, &buffer1, position, cx)
1720        });
1721        let (_request, respond_tx) = req_rx.next().await.unwrap();
1722
1723        respond_tx
1724            .send(model_response(indoc! {r"
1725                --- a/root/1.txt
1726                +++ b/root/1.txt
1727                @@ ... @@
1728                 Hello!
1729                -How
1730                +How are you?
1731                 Bye
1732            "}))
1733            .unwrap();
1734        prediction_task.await.unwrap();
1735
1736        zeta.read_with(cx, |zeta, cx| {
1737            let prediction = zeta
1738                .current_prediction_for_buffer(&buffer1, &project, cx)
1739                .unwrap();
1740            assert_matches!(prediction, BufferEditPrediction::Local { .. });
1741        });
1742
1743        // Context refresh
1744        let refresh_task = zeta.update(cx, |zeta, cx| {
1745            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
1746        });
1747        let (_request, respond_tx) = req_rx.next().await.unwrap();
1748        respond_tx
1749            .send(open_ai::Response {
1750                id: Uuid::new_v4().to_string(),
1751                object: "response".into(),
1752                created: 0,
1753                model: "model".into(),
1754                choices: vec![open_ai::Choice {
1755                    index: 0,
1756                    message: open_ai::RequestMessage::Assistant {
1757                        content: None,
1758                        tool_calls: vec![open_ai::ToolCall {
1759                            id: "search".into(),
1760                            content: open_ai::ToolCallContent::Function {
1761                                function: open_ai::FunctionContent {
1762                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
1763                                        .to_string(),
1764                                    arguments: serde_json::to_string(&SearchToolInput {
1765                                        queries: Box::new([SearchToolQuery {
1766                                            glob: "root/2.txt".to_string(),
1767                                            syntax_node: vec![],
1768                                            content: Some(".".into()),
1769                                        }]),
1770                                    })
1771                                    .unwrap(),
1772                                },
1773                            },
1774                        }],
1775                    },
1776                    finish_reason: None,
1777                }],
1778                usage: Usage {
1779                    prompt_tokens: 0,
1780                    completion_tokens: 0,
1781                    total_tokens: 0,
1782                },
1783            })
1784            .unwrap();
1785        refresh_task.await.unwrap();
1786
1787        zeta.update(cx, |zeta, _cx| {
1788            zeta.discard_current_prediction(&project);
1789        });
1790
1791        // Prediction for another file
1792        let prediction_task = zeta.update(cx, |zeta, cx| {
1793            zeta.refresh_prediction(&project, &buffer1, position, cx)
1794        });
1795        let (_request, respond_tx) = req_rx.next().await.unwrap();
1796        respond_tx
1797            .send(model_response(indoc! {r#"
1798                --- a/root/2.txt
1799                +++ b/root/2.txt
1800                 Hola!
1801                -Como
1802                +Como estas?
1803                 Adios
1804            "#}))
1805            .unwrap();
1806        prediction_task.await.unwrap();
1807        zeta.read_with(cx, |zeta, cx| {
1808            let prediction = zeta
1809                .current_prediction_for_buffer(&buffer1, &project, cx)
1810                .unwrap();
1811            assert_matches!(
1812                prediction,
1813                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
1814            );
1815        });
1816
1817        let buffer2 = project
1818            .update(cx, |project, cx| {
1819                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
1820                project.open_buffer(path, cx)
1821            })
1822            .await
1823            .unwrap();
1824
1825        zeta.read_with(cx, |zeta, cx| {
1826            let prediction = zeta
1827                .current_prediction_for_buffer(&buffer2, &project, cx)
1828                .unwrap();
1829            assert_matches!(prediction, BufferEditPrediction::Local { .. });
1830        });
1831    }
1832
1833    #[gpui::test]
1834    async fn test_simple_request(cx: &mut TestAppContext) {
1835        let (zeta, mut req_rx) = init_test(cx);
1836        let fs = FakeFs::new(cx.executor());
1837        fs.insert_tree(
1838            "/root",
1839            json!({
1840                "foo.md": "Hello!\nHow\nBye\n"
1841            }),
1842        )
1843        .await;
1844        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1845
1846        let buffer = project
1847            .update(cx, |project, cx| {
1848                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1849                project.open_buffer(path, cx)
1850            })
1851            .await
1852            .unwrap();
1853        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1854        let position = snapshot.anchor_before(language::Point::new(1, 3));
1855
1856        let prediction_task = zeta.update(cx, |zeta, cx| {
1857            zeta.request_prediction(&project, &buffer, position, cx)
1858        });
1859
1860        let (_, respond_tx) = req_rx.next().await.unwrap();
1861
1862        // TODO Put back when we have a structured request again
1863        // assert_eq!(
1864        //     request.excerpt_path.as_ref(),
1865        //     Path::new(path!("root/foo.md"))
1866        // );
1867        // assert_eq!(
1868        //     request.cursor_point,
1869        //     Point {
1870        //         line: Line(1),
1871        //         column: 3
1872        //     }
1873        // );
1874
1875        respond_tx
1876            .send(model_response(indoc! { r"
1877                --- a/root/foo.md
1878                +++ b/root/foo.md
1879                @@ ... @@
1880                 Hello!
1881                -How
1882                +How are you?
1883                 Bye
1884            "}))
1885            .unwrap();
1886
1887        let prediction = prediction_task.await.unwrap().unwrap();
1888
1889        assert_eq!(prediction.edits.len(), 1);
1890        assert_eq!(
1891            prediction.edits[0].0.to_point(&snapshot).start,
1892            language::Point::new(1, 3)
1893        );
1894        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
1895    }
1896
1897    #[gpui::test]
1898    async fn test_request_events(cx: &mut TestAppContext) {
1899        let (zeta, mut req_rx) = init_test(cx);
1900        let fs = FakeFs::new(cx.executor());
1901        fs.insert_tree(
1902            "/root",
1903            json!({
1904                "foo.md": "Hello!\n\nBye\n"
1905            }),
1906        )
1907        .await;
1908        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1909
1910        let buffer = project
1911            .update(cx, |project, cx| {
1912                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1913                project.open_buffer(path, cx)
1914            })
1915            .await
1916            .unwrap();
1917
1918        zeta.update(cx, |zeta, cx| {
1919            zeta.register_buffer(&buffer, &project, cx);
1920        });
1921
1922        buffer.update(cx, |buffer, cx| {
1923            buffer.edit(vec![(7..7, "How")], None, cx);
1924        });
1925
1926        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1927        let position = snapshot.anchor_before(language::Point::new(1, 3));
1928
1929        let prediction_task = zeta.update(cx, |zeta, cx| {
1930            zeta.request_prediction(&project, &buffer, position, cx)
1931        });
1932
1933        let (request, respond_tx) = req_rx.next().await.unwrap();
1934
1935        let prompt = prompt_from_request(&request);
1936        assert!(
1937            prompt.contains(indoc! {"
1938            --- a/root/foo.md
1939            +++ b/root/foo.md
1940            @@ -1,3 +1,3 @@
1941             Hello!
1942            -
1943            +How
1944             Bye
1945        "}),
1946            "{prompt}"
1947        );
1948
1949        respond_tx
1950            .send(model_response(indoc! {r#"
1951                --- a/root/foo.md
1952                +++ b/root/foo.md
1953                @@ ... @@
1954                 Hello!
1955                -How
1956                +How are you?
1957                 Bye
1958            "#}))
1959            .unwrap();
1960
1961        let prediction = prediction_task.await.unwrap().unwrap();
1962
1963        assert_eq!(prediction.edits.len(), 1);
1964        assert_eq!(
1965            prediction.edits[0].0.to_point(&snapshot).start,
1966            language::Point::new(1, 3)
1967        );
1968        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
1969    }
1970
1971    // Skipped until we start including diagnostics in prompt
1972    // #[gpui::test]
1973    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
1974    //     let (zeta, mut req_rx) = init_test(cx);
1975    //     let fs = FakeFs::new(cx.executor());
1976    //     fs.insert_tree(
1977    //         "/root",
1978    //         json!({
1979    //             "foo.md": "Hello!\nBye"
1980    //         }),
1981    //     )
1982    //     .await;
1983    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1984
1985    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1986    //     let diagnostic = lsp::Diagnostic {
1987    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1988    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
1989    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1990    //         ..Default::default()
1991    //     };
1992
1993    //     project.update(cx, |project, cx| {
1994    //         project.lsp_store().update(cx, |lsp_store, cx| {
1995    //             // Create some diagnostics
1996    //             lsp_store
1997    //                 .update_diagnostics(
1998    //                     LanguageServerId(0),
1999    //                     lsp::PublishDiagnosticsParams {
2000    //                         uri: path_to_buffer_uri.clone(),
2001    //                         diagnostics: vec![diagnostic],
2002    //                         version: None,
2003    //                     },
2004    //                     None,
2005    //                     language::DiagnosticSourceKind::Pushed,
2006    //                     &[],
2007    //                     cx,
2008    //                 )
2009    //                 .unwrap();
2010    //         });
2011    //     });
2012
2013    //     let buffer = project
2014    //         .update(cx, |project, cx| {
2015    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2016    //             project.open_buffer(path, cx)
2017    //         })
2018    //         .await
2019    //         .unwrap();
2020
2021    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2022    //     let position = snapshot.anchor_before(language::Point::new(0, 0));
2023
2024    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
2025    //         zeta.request_prediction(&project, &buffer, position, cx)
2026    //     });
2027
2028    //     let (request, _respond_tx) = req_rx.next().await.unwrap();
2029
2030    //     assert_eq!(request.diagnostic_groups.len(), 1);
2031    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
2032    //         .unwrap();
2033    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
2034    //     assert_eq!(
2035    //         value,
2036    //         json!({
2037    //             "entries": [{
2038    //                 "range": {
2039    //                     "start": 8,
2040    //                     "end": 10
2041    //                 },
2042    //                 "diagnostic": {
2043    //                     "source": null,
2044    //                     "code": null,
2045    //                     "code_description": null,
2046    //                     "severity": 1,
2047    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
2048    //                     "markdown": null,
2049    //                     "group_id": 0,
2050    //                     "is_primary": true,
2051    //                     "is_disk_based": false,
2052    //                     "is_unnecessary": false,
2053    //                     "source_kind": "Pushed",
2054    //                     "data": null,
2055    //                     "underline": true
2056    //                 }
2057    //             }],
2058    //             "primary_ix": 0
2059    //         })
2060    //     );
2061    // }
2062
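        /// Builds a minimal completion response whose single choice is a plain-text
        /// assistant message containing `text`.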
2063    fn model_response(text: &str) -> open_ai::Response {
2064        open_ai::Response {
2065            id: Uuid::new_v4().to_string(),
2066            object: "response".into(),
2067            created: 0,
2068            model: "model".into(),
2069            choices: vec![open_ai::Choice {
2070                index: 0,
2071                message: open_ai::RequestMessage::Assistant {
2072                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
2073                    tool_calls: vec![],
2074                },
2075                finish_reason: None,
2076            }],
2077            usage: Usage {
2078                prompt_tokens: 0,
2079                completion_tokens: 0,
2080                total_tokens: 0,
2081            },
2082        }
2083    }
2084
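        /// Returns the prompt text, asserting that the request consists of a single
        /// plain-text user message.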
2085    fn prompt_from_request(request: &open_ai::Request) -> &str {
2086        assert_eq!(request.messages.len(), 1);
2087        let open_ai::RequestMessage::User {
2088            content: open_ai::MessageContent::Plain(content),
2089            ..
2090        } = &request.messages[0]
2091        else {
2092            panic!(
2093                "Request does not have single user message of type Plain. {:#?}",
2094                request
2095            );
2096        };
2097        content
2098    }
2099
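        /// Creates a `Zeta` entity backed by a fake HTTP client: `/client/llm_tokens` is
        /// stubbed with a static token, and each request to `/predict_edits/raw` is forwarded
        /// over the returned channel along with a oneshot sender for its response.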
2100    fn init_test(
2101        cx: &mut TestAppContext,
2102    ) -> (
2103        Entity<Zeta>,
2104        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
2105    ) {
2106        cx.update(move |cx| {
2107            let settings_store = SettingsStore::test(cx);
2108            cx.set_global(settings_store);
2109            zlog::init_test();
2110
2111            let (req_tx, req_rx) = mpsc::unbounded();
2112
2113            let http_client = FakeHttpClient::create({
2114                move |req| {
2115                    let uri = req.uri().path().to_string();
2116                    let mut body = req.into_body();
2117                    let req_tx = req_tx.clone();
2118                    async move {
2119                        let resp = match uri.as_str() {
2120                            "/client/llm_tokens" => serde_json::to_string(&json!({
2121                                "token": "test"
2122                            }))
2123                            .unwrap(),
2124                            "/predict_edits/raw" => {
2125                                let mut buf = Vec::new();
2126                                body.read_to_end(&mut buf).await.ok();
2127                                let req = serde_json::from_slice(&buf).unwrap();
2128
2129                                let (res_tx, res_rx) = oneshot::channel();
2130                                req_tx.unbounded_send((req, res_tx)).unwrap();
2131                                serde_json::to_string(&res_rx.await?).unwrap()
2132                            }
2133                            _ => {
2134                                panic!("Unexpected path: {}", uri)
2135                            }
2136                        };
2137
2138                        Ok(Response::builder().body(resp.into()).unwrap())
2139                    }
2140                }
2141            });
2142
2143            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
2144            client.cloud_client().set_credentials(1, "test".into());
2145
2146            language_model::init(client.clone(), cx);
2147
2148            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2149            let zeta = Zeta::global(&client, &user_store, cx);
2150
2151            (zeta, req_rx)
2152        })
2153    }
2154}