1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
10use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
11use collections::HashMap;
12use edit_prediction_context::{
13 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
14 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
15 SyntaxIndex, SyntaxIndexState,
16};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::AsyncReadExt as _;
19use futures::channel::{mpsc, oneshot};
20use gpui::http_client::{AsyncBody, Method};
21use gpui::{
22 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
23 http_client, prelude::*,
24};
25use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
26use language::{BufferSnapshot, OffsetRangeExt};
27use language_model::{LlmApiToken, RefreshLlmTokenListener};
28use open_ai::FunctionDefinition;
29use project::Project;
30use release_channel::AppVersion;
31use serde::de::DeserializeOwned;
32use std::collections::{VecDeque, hash_map};
33
34use std::env;
35use std::ops::Range;
36use std::path::Path;
37use std::str::FromStr as _;
38use std::sync::{Arc, LazyLock};
39use std::time::{Duration, Instant};
40use thiserror::Error;
41use util::rel_path::RelPathBuf;
42use util::{LogErrorFuture, TryFutureExt};
43use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
44
45pub mod merge_excerpts;
46mod prediction;
47mod provider;
48pub mod retrieval_search;
49pub mod udiff;
50
51use crate::merge_excerpts::merge_excerpts;
52use crate::prediction::EditPrediction;
53pub use crate::prediction::EditPredictionId;
54pub use provider::ZetaEditPredictionProvider;
55
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Default sizing for the excerpt captured around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context-gathering mode: agentic (model-driven) retrieval.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults used when context is gathered via the syntax index instead.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default top-level options for [`Zeta`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};

// Set ZED_ZETA2_OLLAMA to any non-empty value to route requests to a local
// Ollama server instead of the Zed LLM service.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
// Model identifier sent with each request; override with ZED_ZETA2_MODEL.
static MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
// Explicit prediction endpoint (ZED_PREDICT_EDITS_URL), falling back to the
// local Ollama endpoint when ZED_ZETA2_OLLAMA is set. `None` means the Zed
// LLM service URL is built at request time instead.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
109
/// Feature flag gating the zeta2 edit prediction provider.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Deliberately not auto-enabled for staff.
    fn enabled_for_staff() -> bool {
        false
    }
}

/// Holds the app-global [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
124
/// Coordinates edit prediction across projects: tracks buffer edits,
/// retrieves context, and requests predictions from the LLM service.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Cached token for the LLM service; refreshed on expiry notifications.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // Set when the server reports this client is below the minimum version.
    update_required: bool,
    // When present, debug events are streamed to this channel.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}

/// Tunable parameters for context gathering and prompt construction.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    // Edits closer together than this are coalesced into a single event.
    pub buffer_change_grouping_interval: Duration,
}

/// How context for a prediction request is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Model-driven retrieval via search tool calls.
    Agentic(AgenticContextOptions),
    /// Syntax-index-based declaration retrieval.
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
156
157impl ContextMode {
158 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
159 match self {
160 ContextMode::Agentic(options) => &options.excerpt,
161 ContextMode::Syntax(options) => &options.excerpt,
162 }
163 }
164}
165
/// Events emitted over the debug channel (see [`Zeta::debug_info`]).
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    // Time spent gathering context before the request was sent.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    // The locally-built prompt, or the error message if building failed.
    pub local_prompt: Result<String, String>,
    // Resolves with the model response (or error) and the request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
206
/// Per-project prediction state.
struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    // Recent buffer-change events, oldest first; capped at MAX_EVENT_COUNT.
    events: VecDeque<Event>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    // Retrieved context excerpts per buffer; `None` until the first retrieval.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    // Dropping this task cancels a pending (debounced) refresh.
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    // When the last refresh was requested; used for idle detection.
    refresh_context_timestamp: Option<Instant>,
}

/// The prediction currently offered to the user, plus the id of the buffer
/// that triggered the request.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
223
impl CurrentEditPrediction {
    /// Whether `self` (a newly-arrived prediction) should replace
    /// `old_prediction` as the prediction shown to the user.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // A new prediction that no longer applies to the current buffer
        // contents never replaces anything.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // If the old prediction no longer interpolates, it is stale: replace.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // Single-edit case: only replace when the new edit extends the
            // old one at the same range.
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
260
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets a different buffer; offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}

struct RegisteredBuffer {
    // Snapshot as of the last reported change; diffed against on edit.
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// A user editing event recorded for inclusion in prediction requests.
#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
281
impl Event {
    /// Converts this event into the wire format, returning `None` when the
    /// event carries no information (no rename and an empty diff).
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                // Only record the old path when the file was renamed.
                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2] move to bg?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                // NOTE(review): `path == old_path` only holds when both are
                // `None` (a rename sets `old_path` to `Some`), so renames are
                // always emitted even with an empty diff.
                if path == old_path && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        //todo: Actually detect if this edit was predicted or not
                        predicted: false,
                    })
                }
            }
        }
    }
}
319
320impl Zeta {
321 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
322 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
323 }
324
325 pub fn global(
326 client: &Arc<Client>,
327 user_store: &Entity<UserStore>,
328 cx: &mut App,
329 ) -> Entity<Self> {
330 cx.try_global::<ZetaGlobal>()
331 .map(|global| global.0.clone())
332 .unwrap_or_else(|| {
333 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
334 cx.set_global(ZetaGlobal(zeta.clone()));
335 zeta
336 })
337 }
338
    /// Creates a new `Zeta`, subscribing to LLM-token refresh notifications
    /// so the cached token stays valid.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    // Re-fetch the token whenever the listener fires.
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
        }
    }
364
    /// Enables debug streaming and returns the receiving end of the channel.
    /// Replaces any previously-created debug channel.
    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    /// The currently-active options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the options. Note: syntax indices that already exist keep
    /// the `file_indexing_parallelism` they were created with.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
378
379 pub fn clear_history(&mut self) {
380 for zeta_project in self.projects.values_mut() {
381 zeta_project.events.clear();
382 }
383 }
384
385 pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
386 self.projects
387 .get(&project.entity_id())
388 .map(|project| project.events.iter())
389 .into_iter()
390 .flatten()
391 }
392
393 pub fn context_for_project(
394 &self,
395 project: &Entity<Project>,
396 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
397 self.projects
398 .get(&project.entity_id())
399 .and_then(|project| {
400 Some(
401 project
402 .context
403 .as_ref()?
404 .iter()
405 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
406 )
407 })
408 .into_iter()
409 .flatten()
410 }
411
    /// Current edit prediction usage/quota, if known.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }

    /// Ensures per-project state (including the syntax index) exists for
    /// `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking `buffer` within `project`, recording its edits as
    /// prediction events.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
429
    /// Returns the state for `project`, creating it (and its syntax index)
    /// on first use.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: cx.new(|cx| {
                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                }),
                events: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
            })
    }
450
    /// Returns the registration for `buffer`, subscribing to its edit and
    /// release events on first registration.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record an event whenever the buffer is edited (and
                        // the project is still alive).
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
488
    /// Records a buffer-change event if `buffer` changed since its last
    /// registered snapshot, and returns the latest snapshot.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version != registered_buffer.snapshot.version {
            // Swap the stored snapshot for the new one; the old snapshot
            // becomes the "before" side of the event.
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                buffer_change_grouping_interval,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }
516
    /// Appends `event` to the project's history, coalescing consecutive
    /// edits to the same buffer that fall within the grouping interval.
    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            // Merge only if this event picks up exactly where the previous
            // one left off (same buffer, same version).
            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }
555
    /// Returns the current prediction as seen from `buffer`: `Local` when
    /// the prediction edits this buffer, `Jump` when this buffer requested a
    /// prediction that targets another buffer, `None` otherwise.
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by_buffer_id,
            prediction,
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else if *requested_by_buffer_id == buffer.entity_id() {
            Some(BufferEditPrediction::Jump { prediction })
        } else {
            None
        }
    }
577
    /// Consumes the current prediction and reports its acceptance to the
    /// server (or to `ZED_ACCEPT_PREDICTION_URL` when that is set).
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // Records usage from response headers / surfaces update-required.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
622
623 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
624 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
625 project_state.current_prediction.take();
626 };
627 }
628
629 pub fn refresh_prediction(
630 &mut self,
631 project: &Entity<Project>,
632 buffer: &Entity<Buffer>,
633 position: language::Anchor,
634 cx: &mut Context<Self>,
635 ) -> Task<Result<()>> {
636 let request_task = self.request_prediction(project, buffer, position, cx);
637 let buffer = buffer.clone();
638 let project = project.clone();
639
640 cx.spawn(async move |this, cx| {
641 if let Some(prediction) = request_task.await? {
642 this.update(cx, |this, cx| {
643 let project_state = this
644 .projects
645 .get_mut(&project.entity_id())
646 .context("Project not found")?;
647
648 let new_prediction = CurrentEditPrediction {
649 requested_by_buffer_id: buffer.entity_id(),
650 prediction: prediction,
651 };
652
653 if project_state
654 .current_prediction
655 .as_ref()
656 .is_none_or(|old_prediction| {
657 new_prediction.should_replace_prediction(&old_prediction, cx)
658 })
659 {
660 project_state.current_prediction = Some(new_prediction);
661 }
662 anyhow::Ok(())
663 })??;
664 }
665 Ok(())
666 })
667 }
668
    /// Builds and sends a prediction request for `active_buffer` at
    /// `position`, returning the parsed prediction — or `None` when no
    /// excerpt/context could be gathered or the model returned no text.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Serialize the recorded edit history into wire-format events.
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        // Absolute path of the active file's parent directory, if any.
        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        // Previously-retrieved context excerpts, captured with their
        // buffers, snapshots, and paths. Buffers without a file are skipped.
        let mut included_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&HashMap::default())
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the included files:
                        // insert its range sorted into the active buffer's
                        // entry and move that entry to the end of the list.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            let range_ix = ranges
                                .binary_search_by(|probe| {
                                    probe
                                        .start
                                        .cmp(&excerpt_anchor_range.start, buffer)
                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
                                })
                                .unwrap_or_else(|ix| ix);

                            ranges.insert(range_ix, excerpt_anchor_range);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot,
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges to merged line-based excerpts
                        // for the wire format.
                        let included_files = included_files
                            .iter()
                            .map(|(_, buffer, path, ranges)| {
                                let excerpts = merge_excerpts(
                                    &buffer,
                                    ranges.iter().map(|range| {
                                        let point_range = range.to_point(&buffer);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    }),
                                );
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(buffer.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            // Excerpt content travels via `included_files`;
                            // these fields are unused in agentic mode.
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                // When debugging, forward the request/prompt and keep a
                // channel to later carry the model response.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Debug-build escape hatch to exercise everything up to (but
                // not including) the network request.
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let request = open_ai::Request {
                    model: MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response =
                    Self::send_raw_llm_request(client, llm_token, app_version, request).await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(output_text) = text_from_response(res) else {
                    return Ok((None, usage))
                };

                // The model answers with a unified diff; resolve its paths
                // against the included files to find the target buffer.
                let (edited_buffer_snapshot, edits) =
                    crate::udiff::parse_diff(&output_text, |path| {
                        included_files
                            .iter()
                            .find_map(|(_, buffer, probe_path, ranges)| {
                                if probe_path.as_ref() == path {
                                    Some((buffer, ranges.as_slice()))
                                } else {
                                    None
                                }
                            })
                    })
                    .await?;

                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers, even though we just found the snapshot")?;

                anyhow::Ok((Some((request_id, edited_buffer, edited_buffer_snapshot.clone(), edits)), usage))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }
995
    /// Sends an OpenAI-style request to the raw prediction endpoint (or to
    /// `PREDICT_EDITS_URL` when overridden) and returns the response plus
    /// any usage reported in the headers.
    async fn send_raw_llm_request(
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        request: open_ai::Request,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await
    }
1023
    /// Applies the side effects of an API response: records usage on
    /// success; on `ZedUpdateRequiredError`, sets `update_required` and
    /// shows an update notification. Returns the payload or the error.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }
1067
    /// Sends an authenticated POST built by `build`, enforcing the server's
    /// minimum-version header and retrying once with a refreshed token when
    /// the LLM token has expired.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // The server can demand a newer client regardless of status code.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh once and retry the same request.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1131
    /// Editing after this much inactivity counts as "starting fresh" and
    /// triggers an immediate context refresh.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied while the user is actively typing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);

    // Refresh the related excerpts when the user just begins editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context retrieval only applies to the agentic mode.
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Replacing the previous debounce task cancels it, restarting the
        // debounce window on every keystroke.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
1182
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    /// Plans and runs context-retrieval searches for the buffer around
    /// `cursor_position`, storing the resulting excerpts on the project state.
    ///
    /// Returns a ready `Ok` task when the project is unregistered, the context
    /// mode is not agentic, or no excerpt can be selected at the cursor.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        // Select the excerpt around the cursor that seeds the search prompt.
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the planning prompt from the excerpt plus recent edit events.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Function-local static: the search tool's JSON schema and its
        // description are computed once and shared across all calls.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response =
                Self::send_raw_llm_request(client, llm_token, app_version, request).await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect queries from every call to the search tool; calls to
            // unrecognized tools are logged and skipped.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result =
                retrieval_search::run_retrieval_searches(project.clone(), queries, cx).await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                // Clear the in-flight task handle before storing the result.
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
1379
1380 fn gather_nearby_diagnostics(
1381 cursor_offset: usize,
1382 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1383 snapshot: &BufferSnapshot,
1384 max_diagnostics_bytes: usize,
1385 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1386 // TODO: Could make this more efficient
1387 let mut diagnostic_groups = Vec::new();
1388 for (language_server_id, diagnostics) in diagnostic_sets {
1389 let mut groups = Vec::new();
1390 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1391 diagnostic_groups.extend(
1392 groups
1393 .into_iter()
1394 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1395 );
1396 }
1397
1398 // sort by proximity to cursor
1399 diagnostic_groups.sort_by_key(|group| {
1400 let range = &group.entries[group.primary_ix].range;
1401 if range.start >= cursor_offset {
1402 range.start - cursor_offset
1403 } else if cursor_offset >= range.end {
1404 cursor_offset - range.end
1405 } else {
1406 (cursor_offset - range.start).min(range.end - cursor_offset)
1407 }
1408 });
1409
1410 let mut results = Vec::new();
1411 let mut diagnostic_groups_truncated = false;
1412 let mut diagnostics_byte_count = 0;
1413 for group in diagnostic_groups {
1414 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1415 diagnostics_byte_count += raw_value.get().len();
1416 if diagnostics_byte_count > max_diagnostics_bytes {
1417 diagnostic_groups_truncated = true;
1418 break;
1419 }
1420 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1421 }
1422
1423 (results, diagnostic_groups_truncated)
1424 }
1425
1426 // TODO: Dedupe with similar code in request_prediction?
1427 pub fn cloud_request_for_zeta_cli(
1428 &mut self,
1429 project: &Entity<Project>,
1430 buffer: &Entity<Buffer>,
1431 position: language::Anchor,
1432 cx: &mut Context<Self>,
1433 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1434 let project_state = self.projects.get(&project.entity_id());
1435
1436 let index_state = project_state.map(|state| {
1437 state
1438 .syntax_index
1439 .read_with(cx, |index, _cx| index.state().clone())
1440 });
1441 let options = self.options.clone();
1442 let snapshot = buffer.read(cx).snapshot();
1443 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1444 return Task::ready(Err(anyhow!("No file path for excerpt")));
1445 };
1446 let worktree_snapshots = project
1447 .read(cx)
1448 .worktrees(cx)
1449 .map(|worktree| worktree.read(cx).snapshot())
1450 .collect::<Vec<_>>();
1451
1452 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1453 let mut path = f.worktree.read(cx).absolutize(&f.path);
1454 if path.pop() { Some(path) } else { None }
1455 });
1456
1457 cx.background_spawn(async move {
1458 let index_state = if let Some(index_state) = index_state {
1459 Some(index_state.lock_owned().await)
1460 } else {
1461 None
1462 };
1463
1464 let cursor_point = position.to_point(&snapshot);
1465
1466 let debug_info = true;
1467 EditPredictionContext::gather_context(
1468 cursor_point,
1469 &snapshot,
1470 parent_abs_path.as_deref(),
1471 match &options.context {
1472 ContextMode::Agentic(_) => {
1473 // TODO
1474 panic!("Llm mode not supported in zeta cli yet");
1475 }
1476 ContextMode::Syntax(edit_prediction_context_options) => {
1477 edit_prediction_context_options
1478 }
1479 },
1480 index_state.as_deref(),
1481 )
1482 .context("Failed to select excerpt")
1483 .map(|context| {
1484 make_syntax_context_cloud_request(
1485 excerpt_path.into(),
1486 context,
1487 // TODO pass everything
1488 Vec::new(),
1489 false,
1490 Vec::new(),
1491 false,
1492 None,
1493 debug_info,
1494 &worktree_snapshots,
1495 index_state.as_deref(),
1496 Some(options.max_prompt_bytes),
1497 options.prompt_format,
1498 )
1499 })
1500 })
1501 }
1502
1503 pub fn wait_for_initial_indexing(
1504 &mut self,
1505 project: &Entity<Project>,
1506 cx: &mut App,
1507 ) -> Task<Result<()>> {
1508 let zeta_project = self.get_or_init_zeta_project(project, cx);
1509 zeta_project
1510 .syntax_index
1511 .read(cx)
1512 .wait_for_initial_file_indexing(cx)
1513 }
1514}
1515
1516pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1517 let choice = res.choices.pop()?;
1518 let output_text = match choice.message {
1519 open_ai::RequestMessage::Assistant {
1520 content: Some(open_ai::MessageContent::Plain(content)),
1521 ..
1522 } => content,
1523 open_ai::RequestMessage::Assistant {
1524 content: Some(open_ai::MessageContent::Multipart(mut content)),
1525 ..
1526 } => {
1527 if content.is_empty() {
1528 log::error!("No output from Baseten completion response");
1529 return None;
1530 }
1531
1532 match content.remove(0) {
1533 open_ai::MessagePart::Text { text } => text,
1534 open_ai::MessagePart::Image { .. } => {
1535 log::error!("Expected text, got an image");
1536 return None;
1537 }
1538 }
1539 }
1540 _ => {
1541 log::error!("Invalid response message: {:?}", choice.message);
1542 return None;
1543 }
1544 };
1545 Some(output_text)
1546}
1547
/// Error produced when the cloud service reports (via response header) a
/// minimum required Zed version newer than the running one; surfaced by
/// `send_api_request`.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    // Minimum version the server will accept, parsed from the response header.
    minimum_version: SemanticVersion,
}
1555
1556fn make_syntax_context_cloud_request(
1557 excerpt_path: Arc<Path>,
1558 context: EditPredictionContext,
1559 events: Vec<predict_edits_v3::Event>,
1560 can_collect_data: bool,
1561 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1562 diagnostic_groups_truncated: bool,
1563 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1564 debug_info: bool,
1565 worktrees: &Vec<worktree::Snapshot>,
1566 index_state: Option<&SyntaxIndexState>,
1567 prompt_max_bytes: Option<usize>,
1568 prompt_format: PromptFormat,
1569) -> predict_edits_v3::PredictEditsRequest {
1570 let mut signatures = Vec::new();
1571 let mut declaration_to_signature_index = HashMap::default();
1572 let mut referenced_declarations = Vec::new();
1573
1574 for snippet in context.declarations {
1575 let project_entry_id = snippet.declaration.project_entry_id();
1576 let Some(path) = worktrees.iter().find_map(|worktree| {
1577 worktree.entry_for_id(project_entry_id).map(|entry| {
1578 let mut full_path = RelPathBuf::new();
1579 full_path.push(worktree.root_name());
1580 full_path.push(&entry.path);
1581 full_path
1582 })
1583 }) else {
1584 continue;
1585 };
1586
1587 let parent_index = index_state.and_then(|index_state| {
1588 snippet.declaration.parent().and_then(|parent| {
1589 add_signature(
1590 parent,
1591 &mut declaration_to_signature_index,
1592 &mut signatures,
1593 index_state,
1594 )
1595 })
1596 });
1597
1598 let (text, text_is_truncated) = snippet.declaration.item_text();
1599 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1600 path: path.as_std_path().into(),
1601 text: text.into(),
1602 range: snippet.declaration.item_line_range(),
1603 text_is_truncated,
1604 signature_range: snippet.declaration.signature_range_in_item_text(),
1605 parent_index,
1606 signature_score: snippet.score(DeclarationStyle::Signature),
1607 declaration_score: snippet.score(DeclarationStyle::Declaration),
1608 score_components: snippet.components,
1609 });
1610 }
1611
1612 let excerpt_parent = index_state.and_then(|index_state| {
1613 context
1614 .excerpt
1615 .parent_declarations
1616 .last()
1617 .and_then(|(parent, _)| {
1618 add_signature(
1619 *parent,
1620 &mut declaration_to_signature_index,
1621 &mut signatures,
1622 index_state,
1623 )
1624 })
1625 });
1626
1627 predict_edits_v3::PredictEditsRequest {
1628 excerpt_path,
1629 excerpt: context.excerpt_text.body,
1630 excerpt_line_range: context.excerpt.line_range,
1631 excerpt_range: context.excerpt.range,
1632 cursor_point: predict_edits_v3::Point {
1633 line: predict_edits_v3::Line(context.cursor_point.row),
1634 column: context.cursor_point.column,
1635 },
1636 referenced_declarations,
1637 included_files: vec![],
1638 signatures,
1639 excerpt_parent,
1640 events,
1641 can_collect_data,
1642 diagnostic_groups,
1643 diagnostic_groups_truncated,
1644 git_info,
1645 debug_info,
1646 prompt_max_bytes,
1647 prompt_format,
1648 }
1649}
1650
1651fn add_signature(
1652 declaration_id: DeclarationId,
1653 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1654 signatures: &mut Vec<Signature>,
1655 index: &SyntaxIndexState,
1656) -> Option<usize> {
1657 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1658 return Some(*signature_index);
1659 }
1660 let Some(parent_declaration) = index.declaration(declaration_id) else {
1661 log::error!("bug: missing parent declaration");
1662 return None;
1663 };
1664 let parent_index = parent_declaration.parent().and_then(|parent| {
1665 add_signature(parent, declaration_to_signature_index, signatures, index)
1666 });
1667 let (text, text_is_truncated) = parent_declaration.signature_text();
1668 let signature_index = signatures.len();
1669 signatures.push(Signature {
1670 text: text.into(),
1671 text_is_truncated,
1672 parent_index,
1673 range: parent_declaration.signature_line_range(),
1674 });
1675 declaration_to_signature_index.insert(declaration_id, signature_index);
1676 Some(signature_index)
1677}
1678
1679#[cfg(test)]
1680mod tests {
1681 use std::{path::Path, sync::Arc};
1682
1683 use client::UserStore;
1684 use clock::FakeSystemClock;
1685 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
1686 use futures::{
1687 AsyncReadExt, StreamExt,
1688 channel::{mpsc, oneshot},
1689 };
1690 use gpui::{
1691 Entity, TestAppContext,
1692 http_client::{FakeHttpClient, Response},
1693 prelude::*,
1694 };
1695 use indoc::indoc;
1696 use language::OffsetRangeExt as _;
1697 use open_ai::Usage;
1698 use pretty_assertions::{assert_eq, assert_matches};
1699 use project::{FakeFs, Project};
1700 use serde_json::json;
1701 use settings::SettingsStore;
1702 use util::path;
1703 use uuid::Uuid;
1704
1705 use crate::{BufferEditPrediction, Zeta};
1706
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        // Exercises how predictions surface per buffer: an edit targeting the
        // current file reads as `Local`, one targeting another file reads as
        // `Jump`, and becomes `Local` once that file's buffer is opened. A
        // context-refresh round-trip runs in between.
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // Respond with a diff editing the file the cursor is in.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Reply to the planning request with one search tool call targeting
        // the other file.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                Hola!
                -Como
                +Como estas?
                Adios
            "#}))
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once the target buffer is open, the same prediction reads as Local.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
1851
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        // A single prediction round-trip: request a prediction, answer with a
        // one-hunk diff, then verify the resulting edit's anchor and text.
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The diff replaces "How" with "How are you?", which resolves to a
        // single insertion of " are you?" at (1, 3).
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
1915
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        // Verifies that an edit made to a registered buffer shows up in the
        // next prediction request's prompt as a udiff event.
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Insert "How" on the empty second line; this becomes the event.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The edit above should appear in the prompt as a diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                Hello!
                -
                +How
                Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
1989
1990 // Skipped until we start including diagnostics in prompt
1991 // #[gpui::test]
1992 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
1993 // let (zeta, mut req_rx) = init_test(cx);
1994 // let fs = FakeFs::new(cx.executor());
1995 // fs.insert_tree(
1996 // "/root",
1997 // json!({
1998 // "foo.md": "Hello!\nBye"
1999 // }),
2000 // )
2001 // .await;
2002 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2003
2004 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
2005 // let diagnostic = lsp::Diagnostic {
2006 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
2007 // severity: Some(lsp::DiagnosticSeverity::ERROR),
2008 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
2009 // ..Default::default()
2010 // };
2011
2012 // project.update(cx, |project, cx| {
2013 // project.lsp_store().update(cx, |lsp_store, cx| {
2014 // // Create some diagnostics
2015 // lsp_store
2016 // .update_diagnostics(
2017 // LanguageServerId(0),
2018 // lsp::PublishDiagnosticsParams {
2019 // uri: path_to_buffer_uri.clone(),
2020 // diagnostics: vec![diagnostic],
2021 // version: None,
2022 // },
2023 // None,
2024 // language::DiagnosticSourceKind::Pushed,
2025 // &[],
2026 // cx,
2027 // )
2028 // .unwrap();
2029 // });
2030 // });
2031
2032 // let buffer = project
2033 // .update(cx, |project, cx| {
2034 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2035 // project.open_buffer(path, cx)
2036 // })
2037 // .await
2038 // .unwrap();
2039
2040 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2041 // let position = snapshot.anchor_before(language::Point::new(0, 0));
2042
2043 // let _prediction_task = zeta.update(cx, |zeta, cx| {
2044 // zeta.request_prediction(&project, &buffer, position, cx)
2045 // });
2046
2047 // let (request, _respond_tx) = req_rx.next().await.unwrap();
2048
2049 // assert_eq!(request.diagnostic_groups.len(), 1);
2050 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
2051 // .unwrap();
2052 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
2053 // assert_eq!(
2054 // value,
2055 // json!({
2056 // "entries": [{
2057 // "range": {
2058 // "start": 8,
2059 // "end": 10
2060 // },
2061 // "diagnostic": {
2062 // "source": null,
2063 // "code": null,
2064 // "code_description": null,
2065 // "severity": 1,
2066 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
2067 // "markdown": null,
2068 // "group_id": 0,
2069 // "is_primary": true,
2070 // "is_disk_based": false,
2071 // "is_unnecessary": false,
2072 // "source_kind": "Pushed",
2073 // "data": null,
2074 // "underline": true
2075 // }
2076 // }],
2077 // "primary_ix": 0
2078 // })
2079 // );
2080 // }
2081
2082 fn model_response(text: &str) -> open_ai::Response {
2083 open_ai::Response {
2084 id: Uuid::new_v4().to_string(),
2085 object: "response".into(),
2086 created: 0,
2087 model: "model".into(),
2088 choices: vec![open_ai::Choice {
2089 index: 0,
2090 message: open_ai::RequestMessage::Assistant {
2091 content: Some(open_ai::MessageContent::Plain(text.to_string())),
2092 tool_calls: vec![],
2093 },
2094 finish_reason: None,
2095 }],
2096 usage: Usage {
2097 prompt_tokens: 0,
2098 completion_tokens: 0,
2099 total_tokens: 0,
2100 },
2101 }
2102 }
2103
2104 fn prompt_from_request(request: &open_ai::Request) -> &str {
2105 assert_eq!(request.messages.len(), 1);
2106 let open_ai::RequestMessage::User {
2107 content: open_ai::MessageContent::Plain(content),
2108 ..
2109 } = &request.messages[0]
2110 else {
2111 panic!(
2112 "Request does not have single user message of type Plain. {:#?}",
2113 request
2114 );
2115 };
2116 content
2117 }
2118
    /// Sets up a `Zeta` entity backed by a fake HTTP client.
    ///
    /// The fake server answers `/client/llm_tokens` with a static token, and
    /// forwards each `/predict_edits/raw` request body — paired with a oneshot
    /// responder — over the returned channel so tests can inspect requests and
    /// reply with canned `open_ai::Response` values. Any other path panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the request to the test and block this
                                // fake response on the test's reply.
                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
2173}