zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
   7};
   8use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
   9use edit_prediction_context::{
  10    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  11    EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
  12};
  13use futures::AsyncReadExt as _;
  14use futures::channel::{mpsc, oneshot};
  15use gpui::http_client::Method;
  16use gpui::{
  17    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  18    http_client, prelude::*,
  19};
  20use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  21use language::{BufferSnapshot, TextBufferSnapshot};
  22use language_model::{LlmApiToken, RefreshLlmTokenListener};
  23use project::Project;
  24use release_channel::AppVersion;
  25use std::collections::{HashMap, VecDeque, hash_map};
  26use std::path::Path;
  27use std::str::FromStr as _;
  28use std::sync::Arc;
  29use std::time::{Duration, Instant};
  30use thiserror::Error;
  31use util::rel_path::RelPathBuf;
  32use util::some_or_debug_panic;
  33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  34
  35mod prediction;
  36mod provider;
  37
  38use crate::prediction::EditPrediction;
  39pub use provider::ZetaEditPredictionProvider;
  40
  41const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  42
  43/// Maximum number of events to track.
  44const MAX_EVENT_COUNT: usize = 16;
  45
  46pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
  47    use_imports: true,
  48    excerpt: EditPredictionExcerptOptions {
  49        max_bytes: 512,
  50        min_bytes: 128,
  51        target_before_cursor_over_total_bytes: 0.5,
  52    },
  53    score: EditPredictionScoreOptions {
  54        omit_excerpt_overlaps: true,
  55        prefilter_score_ratio: 0.5,
  56    },
  57};
  58
  59pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
  60    context: DEFAULT_CONTEXT_OPTIONS,
  61    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
  62    max_diagnostic_bytes: 2048,
  63    prompt_format: PromptFormat::DEFAULT,
  64    file_indexing_parallelism: 1,
  65};
  66
  67#[derive(Clone)]
  68struct ZetaGlobal(Entity<Zeta>);
  69
  70impl Global for ZetaGlobal {}
  71
  72pub struct Zeta {
  73    client: Arc<Client>,
  74    user_store: Entity<UserStore>,
  75    llm_token: LlmApiToken,
  76    _llm_token_subscription: Subscription,
  77    projects: HashMap<EntityId, ZetaProject>,
  78    options: ZetaOptions,
  79    update_required: bool,
  80    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
  81}
  82
  83#[derive(Debug, Clone, PartialEq)]
  84pub struct ZetaOptions {
  85    pub context: EditPredictionContextOptions,
  86    pub max_prompt_bytes: usize,
  87    pub max_diagnostic_bytes: usize,
  88    pub prompt_format: predict_edits_v3::PromptFormat,
  89    pub file_indexing_parallelism: usize,
  90}
  91
  92pub struct PredictionDebugInfo {
  93    pub context: EditPredictionContext,
  94    pub retrieval_time: TimeDelta,
  95    pub buffer: WeakEntity<Buffer>,
  96    pub position: language::Anchor,
  97    pub local_prompt: Result<String, String>,
  98    pub response_rx: oneshot::Receiver<Result<RequestDebugInfo, String>>,
  99}
 100
 101pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 102
 103struct ZetaProject {
 104    syntax_index: Entity<SyntaxIndex>,
 105    events: VecDeque<Event>,
 106    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 107    current_prediction: Option<CurrentEditPrediction>,
 108}
 109
 110#[derive(Clone)]
 111struct CurrentEditPrediction {
 112    pub requested_by_buffer_id: EntityId,
 113    pub prediction: EditPrediction,
 114}
 115
 116impl CurrentEditPrediction {
 117    fn should_replace_prediction(
 118        &self,
 119        old_prediction: &Self,
 120        snapshot: &TextBufferSnapshot,
 121    ) -> bool {
 122        if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id {
 123            return true;
 124        }
 125
 126        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
 127            return true;
 128        };
 129
 130        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
 131            return false;
 132        };
 133        if old_edits.len() == 1 && new_edits.len() == 1 {
 134            let (old_range, old_text) = &old_edits[0];
 135            let (new_range, new_text) = &new_edits[0];
 136            new_range == old_range && new_text.starts_with(old_text)
 137        } else {
 138            true
 139        }
 140    }
 141}
 142
 143/// A prediction from the perspective of a buffer.
 144#[derive(Debug)]
 145enum BufferEditPrediction<'a> {
 146    Local { prediction: &'a EditPrediction },
 147    Jump { prediction: &'a EditPrediction },
 148}
 149
 150struct RegisteredBuffer {
 151    snapshot: BufferSnapshot,
 152    _subscriptions: [gpui::Subscription; 2],
 153}
 154
 155#[derive(Clone)]
 156pub enum Event {
 157    BufferChange {
 158        old_snapshot: BufferSnapshot,
 159        new_snapshot: BufferSnapshot,
 160        timestamp: Instant,
 161    },
 162}
 163
 164impl Zeta {
 165    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 166        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 167    }
 168
 169    pub fn global(
 170        client: &Arc<Client>,
 171        user_store: &Entity<UserStore>,
 172        cx: &mut App,
 173    ) -> Entity<Self> {
 174        cx.try_global::<ZetaGlobal>()
 175            .map(|global| global.0.clone())
 176            .unwrap_or_else(|| {
 177                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 178                cx.set_global(ZetaGlobal(zeta.clone()));
 179                zeta
 180            })
 181    }
 182
 183    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 184        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 185
 186        Self {
 187            projects: HashMap::new(),
 188            client,
 189            user_store,
 190            options: DEFAULT_OPTIONS,
 191            llm_token: LlmApiToken::default(),
 192            _llm_token_subscription: cx.subscribe(
 193                &refresh_llm_token_listener,
 194                |this, _listener, _event, cx| {
 195                    let client = this.client.clone();
 196                    let llm_token = this.llm_token.clone();
 197                    cx.spawn(async move |_this, _cx| {
 198                        llm_token.refresh(&client).await?;
 199                        anyhow::Ok(())
 200                    })
 201                    .detach_and_log_err(cx);
 202                },
 203            ),
 204            update_required: false,
 205            debug_tx: None,
 206        }
 207    }
 208
 209    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
 210        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 211        self.debug_tx = Some(debug_watch_tx);
 212        debug_watch_rx
 213    }
 214
 215    pub fn options(&self) -> &ZetaOptions {
 216        &self.options
 217    }
 218
 219    pub fn set_options(&mut self, options: ZetaOptions) {
 220        self.options = options;
 221    }
 222
 223    pub fn clear_history(&mut self) {
 224        for zeta_project in self.projects.values_mut() {
 225            zeta_project.events.clear();
 226        }
 227    }
 228
 229    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 230        self.user_store.read(cx).edit_prediction_usage()
 231    }
 232
 233    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 234        self.get_or_init_zeta_project(project, cx);
 235    }
 236
 237    pub fn register_buffer(
 238        &mut self,
 239        buffer: &Entity<Buffer>,
 240        project: &Entity<Project>,
 241        cx: &mut Context<Self>,
 242    ) {
 243        let zeta_project = self.get_or_init_zeta_project(project, cx);
 244        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 245    }
 246
 247    fn get_or_init_zeta_project(
 248        &mut self,
 249        project: &Entity<Project>,
 250        cx: &mut App,
 251    ) -> &mut ZetaProject {
 252        self.projects
 253            .entry(project.entity_id())
 254            .or_insert_with(|| ZetaProject {
 255                syntax_index: cx.new(|cx| {
 256                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 257                }),
 258                events: VecDeque::new(),
 259                registered_buffers: HashMap::new(),
 260                current_prediction: None,
 261            })
 262    }
 263
 264    fn register_buffer_impl<'a>(
 265        zeta_project: &'a mut ZetaProject,
 266        buffer: &Entity<Buffer>,
 267        project: &Entity<Project>,
 268        cx: &mut Context<Self>,
 269    ) -> &'a mut RegisteredBuffer {
 270        let buffer_id = buffer.entity_id();
 271        match zeta_project.registered_buffers.entry(buffer_id) {
 272            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 273            hash_map::Entry::Vacant(entry) => {
 274                let snapshot = buffer.read(cx).snapshot();
 275                let project_entity_id = project.entity_id();
 276                entry.insert(RegisteredBuffer {
 277                    snapshot,
 278                    _subscriptions: [
 279                        cx.subscribe(buffer, {
 280                            let project = project.downgrade();
 281                            move |this, buffer, event, cx| {
 282                                if let language::BufferEvent::Edited = event
 283                                    && let Some(project) = project.upgrade()
 284                                {
 285                                    this.report_changes_for_buffer(&buffer, &project, cx);
 286                                }
 287                            }
 288                        }),
 289                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 290                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 291                            else {
 292                                return;
 293                            };
 294                            zeta_project.registered_buffers.remove(&buffer_id);
 295                        }),
 296                    ],
 297                })
 298            }
 299        }
 300    }
 301
 302    fn report_changes_for_buffer(
 303        &mut self,
 304        buffer: &Entity<Buffer>,
 305        project: &Entity<Project>,
 306        cx: &mut Context<Self>,
 307    ) -> BufferSnapshot {
 308        let zeta_project = self.get_or_init_zeta_project(project, cx);
 309        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 310
 311        let new_snapshot = buffer.read(cx).snapshot();
 312        if new_snapshot.version != registered_buffer.snapshot.version {
 313            let old_snapshot =
 314                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 315            Self::push_event(
 316                zeta_project,
 317                Event::BufferChange {
 318                    old_snapshot,
 319                    new_snapshot: new_snapshot.clone(),
 320                    timestamp: Instant::now(),
 321                },
 322            );
 323        }
 324
 325        new_snapshot
 326    }
 327
 328    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 329        let events = &mut zeta_project.events;
 330
 331        if let Some(Event::BufferChange {
 332            new_snapshot: last_new_snapshot,
 333            timestamp: last_timestamp,
 334            ..
 335        }) = events.back_mut()
 336        {
 337            // Coalesce edits for the same buffer when they happen one after the other.
 338            let Event::BufferChange {
 339                old_snapshot,
 340                new_snapshot,
 341                timestamp,
 342            } = &event;
 343
 344            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 345                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 346                && old_snapshot.version == last_new_snapshot.version
 347            {
 348                *last_new_snapshot = new_snapshot.clone();
 349                *last_timestamp = *timestamp;
 350                return;
 351            }
 352        }
 353
 354        if events.len() >= MAX_EVENT_COUNT {
 355            // These are halved instead of popping to improve prompt caching.
 356            events.drain(..MAX_EVENT_COUNT / 2);
 357        }
 358
 359        events.push_back(event);
 360    }
 361
 362    fn current_prediction_for_buffer(
 363        &self,
 364        buffer: &Entity<Buffer>,
 365        project: &Entity<Project>,
 366        cx: &App,
 367    ) -> Option<BufferEditPrediction<'_>> {
 368        let project_state = self.projects.get(&project.entity_id())?;
 369
 370        let CurrentEditPrediction {
 371            requested_by_buffer_id,
 372            prediction,
 373        } = project_state.current_prediction.as_ref()?;
 374
 375        if prediction.targets_buffer(buffer.read(cx), cx) {
 376            Some(BufferEditPrediction::Local { prediction })
 377        } else if *requested_by_buffer_id == buffer.entity_id() {
 378            Some(BufferEditPrediction::Jump { prediction })
 379        } else {
 380            None
 381        }
 382    }
 383
 384    fn accept_current_prediction(&mut self, project: &Entity<Project>) {
 385        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 386            project_state.current_prediction.take();
 387        };
 388        // TODO report accepted
 389    }
 390
 391    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 392        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 393            project_state.current_prediction.take();
 394        };
 395    }
 396
 397    pub fn refresh_prediction(
 398        &mut self,
 399        project: &Entity<Project>,
 400        buffer: &Entity<Buffer>,
 401        position: language::Anchor,
 402        cx: &mut Context<Self>,
 403    ) -> Task<Result<()>> {
 404        let request_task = self.request_prediction(project, buffer, position, cx);
 405        let buffer = buffer.clone();
 406        let project = project.clone();
 407
 408        cx.spawn(async move |this, cx| {
 409            if let Some(prediction) = request_task.await? {
 410                this.update(cx, |this, cx| {
 411                    let project_state = this
 412                        .projects
 413                        .get_mut(&project.entity_id())
 414                        .context("Project not found")?;
 415
 416                    let new_prediction = CurrentEditPrediction {
 417                        requested_by_buffer_id: buffer.entity_id(),
 418                        prediction: prediction,
 419                    };
 420
 421                    if project_state
 422                        .current_prediction
 423                        .as_ref()
 424                        .is_none_or(|old_prediction| {
 425                            new_prediction
 426                                .should_replace_prediction(&old_prediction, buffer.read(cx))
 427                        })
 428                    {
 429                        project_state.current_prediction = Some(new_prediction);
 430                    }
 431                    anyhow::Ok(())
 432                })??;
 433            }
 434            Ok(())
 435        })
 436    }
 437
 438    fn request_prediction(
 439        &mut self,
 440        project: &Entity<Project>,
 441        buffer: &Entity<Buffer>,
 442        position: language::Anchor,
 443        cx: &mut Context<Self>,
 444    ) -> Task<Result<Option<EditPrediction>>> {
 445        let project_state = self.projects.get(&project.entity_id());
 446
 447        let index_state = project_state.map(|state| {
 448            state
 449                .syntax_index
 450                .read_with(cx, |index, _cx| index.state().clone())
 451        });
 452        let options = self.options.clone();
 453        let snapshot = buffer.read(cx).snapshot();
 454        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
 455            return Task::ready(Err(anyhow!("No file path for excerpt")));
 456        };
 457        let client = self.client.clone();
 458        let llm_token = self.llm_token.clone();
 459        let app_version = AppVersion::global(cx);
 460        let worktree_snapshots = project
 461            .read(cx)
 462            .worktrees(cx)
 463            .map(|worktree| worktree.read(cx).snapshot())
 464            .collect::<Vec<_>>();
 465        let debug_tx = self.debug_tx.clone();
 466
 467        let events = project_state
 468            .map(|state| {
 469                state
 470                    .events
 471                    .iter()
 472                    .filter_map(|event| match event {
 473                        Event::BufferChange {
 474                            old_snapshot,
 475                            new_snapshot,
 476                            ..
 477                        } => {
 478                            let path = new_snapshot.file().map(|f| f.full_path(cx));
 479
 480                            let old_path = old_snapshot.file().and_then(|f| {
 481                                let old_path = f.full_path(cx);
 482                                if Some(&old_path) != path.as_ref() {
 483                                    Some(old_path)
 484                                } else {
 485                                    None
 486                                }
 487                            });
 488
 489                            // TODO [zeta2] move to bg?
 490                            let diff =
 491                                language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 492
 493                            if path == old_path && diff.is_empty() {
 494                                None
 495                            } else {
 496                                Some(predict_edits_v3::Event::BufferChange {
 497                                    old_path,
 498                                    path,
 499                                    diff,
 500                                    //todo: Actually detect if this edit was predicted or not
 501                                    predicted: false,
 502                                })
 503                            }
 504                        }
 505                    })
 506                    .collect::<Vec<_>>()
 507            })
 508            .unwrap_or_default();
 509
 510        let diagnostics = snapshot.diagnostic_sets().clone();
 511
 512        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 513            let mut path = f.worktree.read(cx).absolutize(&f.path);
 514            if path.pop() { Some(path) } else { None }
 515        });
 516
 517        let request_task = cx.background_spawn({
 518            let snapshot = snapshot.clone();
 519            let buffer = buffer.clone();
 520            async move {
 521                let index_state = if let Some(index_state) = index_state {
 522                    Some(index_state.lock_owned().await)
 523                } else {
 524                    None
 525                };
 526
 527                let cursor_offset = position.to_offset(&snapshot);
 528                let cursor_point = cursor_offset.to_point(&snapshot);
 529
 530                let before_retrieval = chrono::Utc::now();
 531
 532                let Some(context) = EditPredictionContext::gather_context(
 533                    cursor_point,
 534                    &snapshot,
 535                    parent_abs_path.as_deref(),
 536                    &options.context,
 537                    index_state.as_deref(),
 538                ) else {
 539                    return Ok(None);
 540                };
 541
 542                let retrieval_time = chrono::Utc::now() - before_retrieval;
 543
 544                let (diagnostic_groups, diagnostic_groups_truncated) =
 545                    Self::gather_nearby_diagnostics(
 546                        cursor_offset,
 547                        &diagnostics,
 548                        &snapshot,
 549                        options.max_diagnostic_bytes,
 550                    );
 551
 552                let debug_context = debug_tx.map(|tx| (tx, context.clone()));
 553
 554                let request = make_cloud_request(
 555                    excerpt_path,
 556                    context,
 557                    events,
 558                    // TODO data collection
 559                    false,
 560                    diagnostic_groups,
 561                    diagnostic_groups_truncated,
 562                    None,
 563                    debug_context.is_some(),
 564                    &worktree_snapshots,
 565                    index_state.as_deref(),
 566                    Some(options.max_prompt_bytes),
 567                    options.prompt_format,
 568                );
 569
 570                let debug_response_tx = if let Some((debug_tx, context)) = debug_context {
 571                    let (response_tx, response_rx) = oneshot::channel();
 572
 573                    let local_prompt = PlannedPrompt::populate(&request)
 574                        .and_then(|p| p.to_prompt_string().map(|p| p.0))
 575                        .map_err(|err| err.to_string());
 576
 577                    debug_tx
 578                        .unbounded_send(PredictionDebugInfo {
 579                            context,
 580                            retrieval_time,
 581                            buffer: buffer.downgrade(),
 582                            local_prompt,
 583                            position,
 584                            response_rx,
 585                        })
 586                        .ok();
 587                    Some(response_tx)
 588                } else {
 589                    None
 590                };
 591
 592                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 593                    if let Some(debug_response_tx) = debug_response_tx {
 594                        debug_response_tx
 595                            .send(Err("Request skipped".to_string()))
 596                            .ok();
 597                    }
 598                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 599                }
 600
 601                let response = Self::perform_request(client, llm_token, app_version, request).await;
 602
 603                if let Some(debug_response_tx) = debug_response_tx {
 604                    debug_response_tx
 605                        .send(response.as_ref().map_err(|err| err.to_string()).and_then(
 606                            |response| match some_or_debug_panic(response.0.debug_info.clone()) {
 607                                Some(debug_info) => Ok(debug_info),
 608                                None => Err("Missing debug info".to_string()),
 609                            },
 610                        ))
 611                        .ok();
 612                }
 613
 614                anyhow::Ok(Some(response?))
 615            }
 616        });
 617
 618        let buffer = buffer.clone();
 619
 620        cx.spawn({
 621            let project = project.clone();
 622            async move |this, cx| {
 623                match request_task.await {
 624                    Ok(Some((response, usage))) => {
 625                        if let Some(usage) = usage {
 626                            this.update(cx, |this, cx| {
 627                                this.user_store.update(cx, |user_store, cx| {
 628                                    user_store.update_edit_prediction_usage(usage, cx);
 629                                });
 630                            })
 631                            .ok();
 632                        }
 633
 634                        let prediction = EditPrediction::from_response(
 635                            response, &snapshot, &buffer, &project, cx,
 636                        )
 637                        .await;
 638
 639                        // TODO telemetry: duration, etc
 640                        Ok(prediction)
 641                    }
 642                    Ok(None) => Ok(None),
 643                    Err(err) => {
 644                        if err.is::<ZedUpdateRequiredError>() {
 645                            cx.update(|cx| {
 646                                this.update(cx, |this, _cx| {
 647                                    this.update_required = true;
 648                                })
 649                                .ok();
 650
 651                                let error_message: SharedString = err.to_string().into();
 652                                show_app_notification(
 653                                    NotificationId::unique::<ZedUpdateRequiredError>(),
 654                                    cx,
 655                                    move |cx| {
 656                                        cx.new(|cx| {
 657                                            ErrorMessagePrompt::new(error_message.clone(), cx)
 658                                                .with_link_button(
 659                                                    "Update Zed",
 660                                                    "https://zed.dev/releases",
 661                                                )
 662                                        })
 663                                    },
 664                                );
 665                            })
 666                            .ok();
 667                        }
 668
 669                        Err(err)
 670                    }
 671                }
 672            }
 673        })
 674    }
 675
 676    async fn perform_request(
 677        client: Arc<Client>,
 678        llm_token: LlmApiToken,
 679        app_version: SemanticVersion,
 680        request: predict_edits_v3::PredictEditsRequest,
 681    ) -> Result<(
 682        predict_edits_v3::PredictEditsResponse,
 683        Option<EditPredictionUsage>,
 684    )> {
 685        let http_client = client.http_client();
 686        let mut token = llm_token.acquire(&client).await?;
 687        let mut did_retry = false;
 688
 689        loop {
 690            let request_builder = http_client::Request::builder().method(Method::POST);
 691            let request_builder =
 692                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 693                    request_builder.uri(predict_edits_url)
 694                } else {
 695                    request_builder.uri(
 696                        http_client
 697                            .build_zed_llm_url("/predict_edits/v3", &[])?
 698                            .as_ref(),
 699                    )
 700                };
 701            let request = request_builder
 702                .header("Content-Type", "application/json")
 703                .header("Authorization", format!("Bearer {}", token))
 704                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 705                .body(serde_json::to_string(&request)?.into())?;
 706
 707            let mut response = http_client.send(request).await?;
 708
 709            if let Some(minimum_required_version) = response
 710                .headers()
 711                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 712                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 713            {
 714                anyhow::ensure!(
 715                    app_version >= minimum_required_version,
 716                    ZedUpdateRequiredError {
 717                        minimum_version: minimum_required_version
 718                    }
 719                );
 720            }
 721
 722            if response.status().is_success() {
 723                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 724
 725                let mut body = Vec::new();
 726                response.body_mut().read_to_end(&mut body).await?;
 727                return Ok((serde_json::from_slice(&body)?, usage));
 728            } else if !did_retry
 729                && response
 730                    .headers()
 731                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 732                    .is_some()
 733            {
 734                did_retry = true;
 735                token = llm_token.refresh(&client).await?;
 736            } else {
 737                let mut body = String::new();
 738                response.body_mut().read_to_string(&mut body).await?;
 739                anyhow::bail!(
 740                    "error predicting edits.\nStatus: {:?}\nBody: {}",
 741                    response.status(),
 742                    body
 743                );
 744            }
 745        }
 746    }
 747
 748    fn gather_nearby_diagnostics(
 749        cursor_offset: usize,
 750        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
 751        snapshot: &BufferSnapshot,
 752        max_diagnostics_bytes: usize,
 753    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
 754        // TODO: Could make this more efficient
 755        let mut diagnostic_groups = Vec::new();
 756        for (language_server_id, diagnostics) in diagnostic_sets {
 757            let mut groups = Vec::new();
 758            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
 759            diagnostic_groups.extend(
 760                groups
 761                    .into_iter()
 762                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
 763            );
 764        }
 765
 766        // sort by proximity to cursor
 767        diagnostic_groups.sort_by_key(|group| {
 768            let range = &group.entries[group.primary_ix].range;
 769            if range.start >= cursor_offset {
 770                range.start - cursor_offset
 771            } else if cursor_offset >= range.end {
 772                cursor_offset - range.end
 773            } else {
 774                (cursor_offset - range.start).min(range.end - cursor_offset)
 775            }
 776        });
 777
 778        let mut results = Vec::new();
 779        let mut diagnostic_groups_truncated = false;
 780        let mut diagnostics_byte_count = 0;
 781        for group in diagnostic_groups {
 782            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
 783            diagnostics_byte_count += raw_value.get().len();
 784            if diagnostics_byte_count > max_diagnostics_bytes {
 785                diagnostic_groups_truncated = true;
 786                break;
 787            }
 788            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
 789        }
 790
 791        (results, diagnostic_groups_truncated)
 792    }
 793
 794    // TODO: Dedupe with similar code in request_prediction?
 795    pub fn cloud_request_for_zeta_cli(
 796        &mut self,
 797        project: &Entity<Project>,
 798        buffer: &Entity<Buffer>,
 799        position: language::Anchor,
 800        cx: &mut Context<Self>,
 801    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 802        let project_state = self.projects.get(&project.entity_id());
 803
 804        let index_state = project_state.map(|state| {
 805            state
 806                .syntax_index
 807                .read_with(cx, |index, _cx| index.state().clone())
 808        });
 809        let options = self.options.clone();
 810        let snapshot = buffer.read(cx).snapshot();
 811        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 812            return Task::ready(Err(anyhow!("No file path for excerpt")));
 813        };
 814        let worktree_snapshots = project
 815            .read(cx)
 816            .worktrees(cx)
 817            .map(|worktree| worktree.read(cx).snapshot())
 818            .collect::<Vec<_>>();
 819
 820        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 821            let mut path = f.worktree.read(cx).absolutize(&f.path);
 822            if path.pop() { Some(path) } else { None }
 823        });
 824
 825        cx.background_spawn(async move {
 826            let index_state = if let Some(index_state) = index_state {
 827                Some(index_state.lock_owned().await)
 828            } else {
 829                None
 830            };
 831
 832            let cursor_point = position.to_point(&snapshot);
 833
 834            let debug_info = true;
 835            EditPredictionContext::gather_context(
 836                cursor_point,
 837                &snapshot,
 838                parent_abs_path.as_deref(),
 839                &options.context,
 840                index_state.as_deref(),
 841            )
 842            .context("Failed to select excerpt")
 843            .map(|context| {
 844                make_cloud_request(
 845                    excerpt_path.into(),
 846                    context,
 847                    // TODO pass everything
 848                    Vec::new(),
 849                    false,
 850                    Vec::new(),
 851                    false,
 852                    None,
 853                    debug_info,
 854                    &worktree_snapshots,
 855                    index_state.as_deref(),
 856                    Some(options.max_prompt_bytes),
 857                    options.prompt_format,
 858                )
 859            })
 860        })
 861    }
 862
 863    pub fn wait_for_initial_indexing(
 864        &mut self,
 865        project: &Entity<Project>,
 866        cx: &mut App,
 867    ) -> Task<Result<()>> {
 868        let zeta_project = self.get_or_init_zeta_project(project, cx);
 869        zeta_project
 870            .syntax_index
 871            .read(cx)
 872            .wait_for_initial_file_indexing(cx)
 873    }
 874}
 875
 876#[derive(Error, Debug)]
 877#[error(
 878    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
 879)]
 880pub struct ZedUpdateRequiredError {
 881    minimum_version: SemanticVersion,
 882}
 883
 884fn make_cloud_request(
 885    excerpt_path: Arc<Path>,
 886    context: EditPredictionContext,
 887    events: Vec<predict_edits_v3::Event>,
 888    can_collect_data: bool,
 889    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 890    diagnostic_groups_truncated: bool,
 891    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 892    debug_info: bool,
 893    worktrees: &Vec<worktree::Snapshot>,
 894    index_state: Option<&SyntaxIndexState>,
 895    prompt_max_bytes: Option<usize>,
 896    prompt_format: PromptFormat,
 897) -> predict_edits_v3::PredictEditsRequest {
 898    let mut signatures = Vec::new();
 899    let mut declaration_to_signature_index = HashMap::default();
 900    let mut referenced_declarations = Vec::new();
 901
 902    for snippet in context.declarations {
 903        let project_entry_id = snippet.declaration.project_entry_id();
 904        let Some(path) = worktrees.iter().find_map(|worktree| {
 905            worktree.entry_for_id(project_entry_id).map(|entry| {
 906                let mut full_path = RelPathBuf::new();
 907                full_path.push(worktree.root_name());
 908                full_path.push(&entry.path);
 909                full_path
 910            })
 911        }) else {
 912            continue;
 913        };
 914
 915        let parent_index = index_state.and_then(|index_state| {
 916            snippet.declaration.parent().and_then(|parent| {
 917                add_signature(
 918                    parent,
 919                    &mut declaration_to_signature_index,
 920                    &mut signatures,
 921                    index_state,
 922                )
 923            })
 924        });
 925
 926        let (text, text_is_truncated) = snippet.declaration.item_text();
 927        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
 928            path: path.as_std_path().into(),
 929            text: text.into(),
 930            range: snippet.declaration.item_range(),
 931            text_is_truncated,
 932            signature_range: snippet.declaration.signature_range_in_item_text(),
 933            parent_index,
 934            signature_score: snippet.score(DeclarationStyle::Signature),
 935            declaration_score: snippet.score(DeclarationStyle::Declaration),
 936            score_components: snippet.components,
 937        });
 938    }
 939
 940    let excerpt_parent = index_state.and_then(|index_state| {
 941        context
 942            .excerpt
 943            .parent_declarations
 944            .last()
 945            .and_then(|(parent, _)| {
 946                add_signature(
 947                    *parent,
 948                    &mut declaration_to_signature_index,
 949                    &mut signatures,
 950                    index_state,
 951                )
 952            })
 953    });
 954
 955    predict_edits_v3::PredictEditsRequest {
 956        excerpt_path,
 957        excerpt: context.excerpt_text.body,
 958        excerpt_range: context.excerpt.range,
 959        cursor_offset: context.cursor_offset_in_excerpt,
 960        referenced_declarations,
 961        signatures,
 962        excerpt_parent,
 963        events,
 964        can_collect_data,
 965        diagnostic_groups,
 966        diagnostic_groups_truncated,
 967        git_info,
 968        debug_info,
 969        prompt_max_bytes,
 970        prompt_format,
 971    }
 972}
 973
 974fn add_signature(
 975    declaration_id: DeclarationId,
 976    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
 977    signatures: &mut Vec<Signature>,
 978    index: &SyntaxIndexState,
 979) -> Option<usize> {
 980    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
 981        return Some(*signature_index);
 982    }
 983    let Some(parent_declaration) = index.declaration(declaration_id) else {
 984        log::error!("bug: missing parent declaration");
 985        return None;
 986    };
 987    let parent_index = parent_declaration.parent().and_then(|parent| {
 988        add_signature(parent, declaration_to_signature_index, signatures, index)
 989    });
 990    let (text, text_is_truncated) = parent_declaration.signature_text();
 991    let signature_index = signatures.len();
 992    signatures.push(Signature {
 993        text: text.into(),
 994        text_is_truncated,
 995        parent_index,
 996        range: parent_declaration.signature_range(),
 997    });
 998    declaration_to_signature_index.insert(declaration_id, signature_index);
 999    Some(signature_index)
1000}
1001
/// Tests for `Zeta` prediction requests. They use a fake HTTP client that passes
/// each `/predict_edits/v3` request to the test, which asserts on it and replies.
#[cfg(test)]
mod tests {
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    /// Checks how the current prediction is classified for a buffer:
    /// - a prediction that edits the active buffer is `Local`;
    /// - a prediction that edits another file is a `Jump` to that file;
    /// - once that other file's buffer is opened, the prediction is `Local` for it.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: 0..snapshot1.len(),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    // This edit replaces all of 2.txt, so the range uses 2.txt's length,
                    // not 1.txt's.
                    range: 0.."Hola!\nComo\nAdios".len(),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    /// Checks two things for a single request:
    /// - the request carries the excerpt path and the cursor offset (byte 10 is line 1,
    ///   column 3);
    /// - a response that replaces the whole file becomes one insertion at the cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(request.cursor_offset, 10);

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: 0..snapshot.len(),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Checks that an edit made to a registered buffer is sent in the next request
    /// as a `BufferChange` event carrying a unified diff.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Fill the empty second line; this edit should be reported as an event.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                        @@ -1,3 +1,3 @@
                         Hello!
                        -
                        +How
                         Bye
                    "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: 0..snapshot.len(),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Checks that diagnostics published before the buffer is opened are included in
    /// the request, serialized as JSON with byte-offset ranges.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // Column 5 is past the end of line 1 ("Bye"), so the expected end offset is 10,
        // which is the end of the 10-byte buffer.
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    /// Sets up what the tests need:
    /// - settings, language and project globals;
    /// - a fake HTTP client;
    /// - a global `Zeta` instance.
    ///
    /// The fake client answers `/client/llm_tokens` with a fixed token and panics on
    /// any other unexpected path. Each `/predict_edits/v3` request is sent through the
    /// returned receiver together with a oneshot sender. The HTTP response is held
    /// until the test sends a `PredictEditsResponse` on that sender.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Pass the request to the test and wait for its reply.
                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await.unwrap()).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}