zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
   7};
   8use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
   9use edit_prediction_context::{
  10    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  11    EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
  12};
  13use futures::AsyncReadExt as _;
  14use futures::channel::{mpsc, oneshot};
  15use gpui::http_client::Method;
  16use gpui::{
  17    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  18    http_client, prelude::*,
  19};
  20use language::BufferSnapshot;
  21use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  22use language_model::{LlmApiToken, RefreshLlmTokenListener};
  23use project::Project;
  24use release_channel::AppVersion;
  25use std::collections::{HashMap, VecDeque, hash_map};
  26use std::path::Path;
  27use std::str::FromStr as _;
  28use std::sync::Arc;
  29use std::time::{Duration, Instant};
  30use thiserror::Error;
  31use util::rel_path::RelPathBuf;
  32use util::some_or_debug_panic;
  33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  34
  35mod prediction;
  36mod provider;
  37
  38use crate::prediction::EditPrediction;
  39pub use provider::ZetaEditPredictionProvider;
  40
/// Maximum time between two consecutive edits for them to be coalesced into a
/// single `Event::BufferChange` (the comparison in `Zeta::push_event` is inclusive).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track per project. When the limit is reached,
/// `Zeta::push_event` drops the oldest half before appending the new event.
const MAX_EVENT_COUNT: usize = 16;
  45
/// Default context-gathering options; used as the `context` field of
/// [`DEFAULT_OPTIONS`].
pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
    use_imports: true,
    use_references: false,
    excerpt: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        // NOTE(review): presumably the fraction of the excerpt that should sit
        // before the cursor (0.5 = centered) — confirm against
        // `edit_prediction_context`.
        target_before_cursor_over_total_bytes: 0.5,
    },
    score: EditPredictionScoreOptions {
        omit_excerpt_overlaps: true,
    },
};
  58
/// Options every [`Zeta`] instance starts with (see `Zeta::new`); they can be
/// replaced later through `Zeta::set_options`.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
};
  66
/// Newtype wrapper that lets the shared [`Zeta`] entity be stored as a GPUI
/// global (see `Zeta::global` / `Zeta::try_global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
  71
/// Edit prediction state shared across projects: request configuration,
/// credentials, and per-project edit history and current prediction.
pub struct Zeta {
    // Client used to acquire/refresh the LLM token and to send prediction requests.
    client: Arc<Client>,
    // Source of, and sink for, edit prediction usage information.
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Refreshes `llm_token` whenever the `RefreshLlmTokenListener` emits an event.
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // Set once a request fails with `ZedUpdateRequiredError`.
    update_required: bool,
    // When present, every prediction request sends a `PredictionDebugInfo` here.
    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
}
  82
/// Tunable parameters for building prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Options forwarded to `EditPredictionContext::gather_context`.
    pub context: EditPredictionContextOptions,
    /// Prompt size limit forwarded to the cloud request.
    pub max_prompt_bytes: usize,
    /// Budget, in bytes of serialized JSON, for diagnostics attached to a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format forwarded to the cloud request.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project is first registered.
    pub file_indexing_parallelism: usize,
}
  91
/// Debug record emitted for each prediction request while a debug channel is
/// installed via `Zeta::debug_info`.
pub struct PredictionDebugInfo {
    /// Copy of the context gathered for the request.
    pub context: EditPredictionContext,
    /// Wall-clock time spent inside `EditPredictionContext::gather_context`.
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt rendered locally from the request, or the rendering error as text.
    pub local_prompt: Result<String, String>,
    /// Resolves with the server-provided debug info, or an error message if the
    /// request failed, was skipped, or returned no debug info.
    pub response_rx: oneshot::Receiver<Result<RequestDebugInfo, String>>,
}
 100
/// Debug information attached to a prediction response (its `debug_info` field).
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 102
/// Per-project edit prediction state owned by [`Zeta`].
struct ZetaProject {
    // Index whose state is passed to context gathering and request building.
    syntax_index: Entity<SyntaxIndex>,
    // Recent buffer changes, oldest first; bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    // Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    // The most recent prediction that has not been accepted or discarded.
    current_prediction: Option<CurrentEditPrediction>,
}
 109
/// A prediction together with the buffer it was requested from. The requesting
/// buffer may differ from `prediction.buffer`, the buffer the edits apply to.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
 115
 116impl CurrentEditPrediction {
 117    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 118        let Some(new_edits) = self
 119            .prediction
 120            .interpolate(&self.prediction.buffer.read(cx))
 121        else {
 122            return false;
 123        };
 124
 125        if self.prediction.buffer != old_prediction.prediction.buffer {
 126            return true;
 127        }
 128
 129        let Some(old_edits) = old_prediction
 130            .prediction
 131            .interpolate(&old_prediction.prediction.buffer.read(cx))
 132        else {
 133            return true;
 134        };
 135
 136        // This reduces the occurrence of UI thrash from replacing edits
 137        //
 138        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 139        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 140            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 141            && old_edits.len() == 1
 142            && new_edits.len() == 1
 143        {
 144            let (old_range, old_text) = &old_edits[0];
 145            let (new_range, new_text) = &new_edits[0];
 146            new_range == old_range && new_text.starts_with(old_text)
 147        } else {
 148            true
 149        }
 150    }
 151}
 152
 153/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets the buffer being asked about.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets a different buffer, but was requested from the
    /// buffer being asked about.
    Jump { prediction: &'a EditPrediction },
}
 159
/// Tracking state for a buffer whose edits are recorded as events.
struct RegisteredBuffer {
    // Snapshot as of the last reported change; becomes `old_snapshot` of the next event.
    snapshot: BufferSnapshot,
    // Buffer event subscription and release observer, in that order
    // (see `Zeta::register_buffer_impl`).
    _subscriptions: [gpui::Subscription; 2],
}
 164
/// An entry in a project's edit history, sent along with prediction requests.
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`. Consecutive changes
    /// may be coalesced, in which case `timestamp` is that of the latest one.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
 173
 174impl Zeta {
 175    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 176        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 177    }
 178
 179    pub fn global(
 180        client: &Arc<Client>,
 181        user_store: &Entity<UserStore>,
 182        cx: &mut App,
 183    ) -> Entity<Self> {
 184        cx.try_global::<ZetaGlobal>()
 185            .map(|global| global.0.clone())
 186            .unwrap_or_else(|| {
 187                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 188                cx.set_global(ZetaGlobal(zeta.clone()));
 189                zeta
 190            })
 191    }
 192
 193    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 194        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 195
 196        Self {
 197            projects: HashMap::new(),
 198            client,
 199            user_store,
 200            options: DEFAULT_OPTIONS,
 201            llm_token: LlmApiToken::default(),
 202            _llm_token_subscription: cx.subscribe(
 203                &refresh_llm_token_listener,
 204                |this, _listener, _event, cx| {
 205                    let client = this.client.clone();
 206                    let llm_token = this.llm_token.clone();
 207                    cx.spawn(async move |_this, _cx| {
 208                        llm_token.refresh(&client).await?;
 209                        anyhow::Ok(())
 210                    })
 211                    .detach_and_log_err(cx);
 212                },
 213            ),
 214            update_required: false,
 215            debug_tx: None,
 216        }
 217    }
 218
 219    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
 220        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 221        self.debug_tx = Some(debug_watch_tx);
 222        debug_watch_rx
 223    }
 224
    /// Returns the options currently used for prediction requests.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 228
    /// Replaces the options used for subsequent prediction requests.
    ///
    /// NOTE(review): `file_indexing_parallelism` is only read when a project's
    /// `SyntaxIndex` is first created, so changing it here presumably does not
    /// affect already-registered projects — confirm.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 232
 233    pub fn clear_history(&mut self) {
 234        for zeta_project in self.projects.values_mut() {
 235            zeta_project.events.clear();
 236        }
 237    }
 238
    /// Returns the edit prediction usage currently known to the user store, if any.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }
 242
    /// Ensures per-project state (including its syntax index) exists for `project`.
    /// Does nothing if the project is already registered.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }
 246
    /// Starts tracking edits to `buffer` as part of `project`, registering the
    /// project first if necessary. Does nothing if the buffer is already tracked.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 256
 257    fn get_or_init_zeta_project(
 258        &mut self,
 259        project: &Entity<Project>,
 260        cx: &mut App,
 261    ) -> &mut ZetaProject {
 262        self.projects
 263            .entry(project.entity_id())
 264            .or_insert_with(|| ZetaProject {
 265                syntax_index: cx.new(|cx| {
 266                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 267                }),
 268                events: VecDeque::new(),
 269                registered_buffers: HashMap::new(),
 270                current_prediction: None,
 271            })
 272    }
 273
 274    fn register_buffer_impl<'a>(
 275        zeta_project: &'a mut ZetaProject,
 276        buffer: &Entity<Buffer>,
 277        project: &Entity<Project>,
 278        cx: &mut Context<Self>,
 279    ) -> &'a mut RegisteredBuffer {
 280        let buffer_id = buffer.entity_id();
 281        match zeta_project.registered_buffers.entry(buffer_id) {
 282            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 283            hash_map::Entry::Vacant(entry) => {
 284                let snapshot = buffer.read(cx).snapshot();
 285                let project_entity_id = project.entity_id();
 286                entry.insert(RegisteredBuffer {
 287                    snapshot,
 288                    _subscriptions: [
 289                        cx.subscribe(buffer, {
 290                            let project = project.downgrade();
 291                            move |this, buffer, event, cx| {
 292                                if let language::BufferEvent::Edited = event
 293                                    && let Some(project) = project.upgrade()
 294                                {
 295                                    this.report_changes_for_buffer(&buffer, &project, cx);
 296                                }
 297                            }
 298                        }),
 299                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 300                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 301                            else {
 302                                return;
 303                            };
 304                            zeta_project.registered_buffers.remove(&buffer_id);
 305                        }),
 306                    ],
 307                })
 308            }
 309        }
 310    }
 311
 312    fn report_changes_for_buffer(
 313        &mut self,
 314        buffer: &Entity<Buffer>,
 315        project: &Entity<Project>,
 316        cx: &mut Context<Self>,
 317    ) -> BufferSnapshot {
 318        let zeta_project = self.get_or_init_zeta_project(project, cx);
 319        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 320
 321        let new_snapshot = buffer.read(cx).snapshot();
 322        if new_snapshot.version != registered_buffer.snapshot.version {
 323            let old_snapshot =
 324                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 325            Self::push_event(
 326                zeta_project,
 327                Event::BufferChange {
 328                    old_snapshot,
 329                    new_snapshot: new_snapshot.clone(),
 330                    timestamp: Instant::now(),
 331                },
 332            );
 333        }
 334
 335        new_snapshot
 336    }
 337
 338    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 339        let events = &mut zeta_project.events;
 340
 341        if let Some(Event::BufferChange {
 342            new_snapshot: last_new_snapshot,
 343            timestamp: last_timestamp,
 344            ..
 345        }) = events.back_mut()
 346        {
 347            // Coalesce edits for the same buffer when they happen one after the other.
 348            let Event::BufferChange {
 349                old_snapshot,
 350                new_snapshot,
 351                timestamp,
 352            } = &event;
 353
 354            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 355                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 356                && old_snapshot.version == last_new_snapshot.version
 357            {
 358                *last_new_snapshot = new_snapshot.clone();
 359                *last_timestamp = *timestamp;
 360                return;
 361            }
 362        }
 363
 364        if events.len() >= MAX_EVENT_COUNT {
 365            // These are halved instead of popping to improve prompt caching.
 366            events.drain(..MAX_EVENT_COUNT / 2);
 367        }
 368
 369        events.push_back(event);
 370    }
 371
 372    fn current_prediction_for_buffer(
 373        &self,
 374        buffer: &Entity<Buffer>,
 375        project: &Entity<Project>,
 376        cx: &App,
 377    ) -> Option<BufferEditPrediction<'_>> {
 378        let project_state = self.projects.get(&project.entity_id())?;
 379
 380        let CurrentEditPrediction {
 381            requested_by_buffer_id,
 382            prediction,
 383        } = project_state.current_prediction.as_ref()?;
 384
 385        if prediction.targets_buffer(buffer.read(cx), cx) {
 386            Some(BufferEditPrediction::Local { prediction })
 387        } else if *requested_by_buffer_id == buffer.entity_id() {
 388            Some(BufferEditPrediction::Jump { prediction })
 389        } else {
 390            None
 391        }
 392    }
 393
 394    fn accept_current_prediction(&mut self, project: &Entity<Project>) {
 395        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 396            project_state.current_prediction.take();
 397        };
 398        // TODO report accepted
 399    }
 400
 401    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 402        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 403            project_state.current_prediction.take();
 404        };
 405    }
 406
 407    pub fn refresh_prediction(
 408        &mut self,
 409        project: &Entity<Project>,
 410        buffer: &Entity<Buffer>,
 411        position: language::Anchor,
 412        cx: &mut Context<Self>,
 413    ) -> Task<Result<()>> {
 414        let request_task = self.request_prediction(project, buffer, position, cx);
 415        let buffer = buffer.clone();
 416        let project = project.clone();
 417
 418        cx.spawn(async move |this, cx| {
 419            if let Some(prediction) = request_task.await? {
 420                this.update(cx, |this, cx| {
 421                    let project_state = this
 422                        .projects
 423                        .get_mut(&project.entity_id())
 424                        .context("Project not found")?;
 425
 426                    let new_prediction = CurrentEditPrediction {
 427                        requested_by_buffer_id: buffer.entity_id(),
 428                        prediction: prediction,
 429                    };
 430
 431                    if project_state
 432                        .current_prediction
 433                        .as_ref()
 434                        .is_none_or(|old_prediction| {
 435                            new_prediction.should_replace_prediction(&old_prediction, cx)
 436                        })
 437                    {
 438                        project_state.current_prediction = Some(new_prediction);
 439                    }
 440                    anyhow::Ok(())
 441                })??;
 442            }
 443            Ok(())
 444        })
 445    }
 446
 447    fn request_prediction(
 448        &mut self,
 449        project: &Entity<Project>,
 450        buffer: &Entity<Buffer>,
 451        position: language::Anchor,
 452        cx: &mut Context<Self>,
 453    ) -> Task<Result<Option<EditPrediction>>> {
 454        let project_state = self.projects.get(&project.entity_id());
 455
 456        let index_state = project_state.map(|state| {
 457            state
 458                .syntax_index
 459                .read_with(cx, |index, _cx| index.state().clone())
 460        });
 461        let options = self.options.clone();
 462        let snapshot = buffer.read(cx).snapshot();
 463        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
 464            return Task::ready(Err(anyhow!("No file path for excerpt")));
 465        };
 466        let client = self.client.clone();
 467        let llm_token = self.llm_token.clone();
 468        let app_version = AppVersion::global(cx);
 469        let worktree_snapshots = project
 470            .read(cx)
 471            .worktrees(cx)
 472            .map(|worktree| worktree.read(cx).snapshot())
 473            .collect::<Vec<_>>();
 474        let debug_tx = self.debug_tx.clone();
 475
 476        let events = project_state
 477            .map(|state| {
 478                state
 479                    .events
 480                    .iter()
 481                    .filter_map(|event| match event {
 482                        Event::BufferChange {
 483                            old_snapshot,
 484                            new_snapshot,
 485                            ..
 486                        } => {
 487                            let path = new_snapshot.file().map(|f| f.full_path(cx));
 488
 489                            let old_path = old_snapshot.file().and_then(|f| {
 490                                let old_path = f.full_path(cx);
 491                                if Some(&old_path) != path.as_ref() {
 492                                    Some(old_path)
 493                                } else {
 494                                    None
 495                                }
 496                            });
 497
 498                            // TODO [zeta2] move to bg?
 499                            let diff =
 500                                language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 501
 502                            if path == old_path && diff.is_empty() {
 503                                None
 504                            } else {
 505                                Some(predict_edits_v3::Event::BufferChange {
 506                                    old_path,
 507                                    path,
 508                                    diff,
 509                                    //todo: Actually detect if this edit was predicted or not
 510                                    predicted: false,
 511                                })
 512                            }
 513                        }
 514                    })
 515                    .collect::<Vec<_>>()
 516            })
 517            .unwrap_or_default();
 518
 519        let diagnostics = snapshot.diagnostic_sets().clone();
 520
 521        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 522            let mut path = f.worktree.read(cx).absolutize(&f.path);
 523            if path.pop() { Some(path) } else { None }
 524        });
 525
 526        let request_task = cx.background_spawn({
 527            let snapshot = snapshot.clone();
 528            let buffer = buffer.clone();
 529            async move {
 530                let index_state = if let Some(index_state) = index_state {
 531                    Some(index_state.lock_owned().await)
 532                } else {
 533                    None
 534                };
 535
 536                let cursor_offset = position.to_offset(&snapshot);
 537                let cursor_point = cursor_offset.to_point(&snapshot);
 538
 539                let before_retrieval = chrono::Utc::now();
 540
 541                let Some(context) = EditPredictionContext::gather_context(
 542                    cursor_point,
 543                    &snapshot,
 544                    parent_abs_path.as_deref(),
 545                    &options.context,
 546                    index_state.as_deref(),
 547                ) else {
 548                    return Ok(None);
 549                };
 550
 551                let retrieval_time = chrono::Utc::now() - before_retrieval;
 552
 553                let (diagnostic_groups, diagnostic_groups_truncated) =
 554                    Self::gather_nearby_diagnostics(
 555                        cursor_offset,
 556                        &diagnostics,
 557                        &snapshot,
 558                        options.max_diagnostic_bytes,
 559                    );
 560
 561                let debug_context = debug_tx.map(|tx| (tx, context.clone()));
 562
 563                let request = make_cloud_request(
 564                    excerpt_path,
 565                    context,
 566                    events,
 567                    // TODO data collection
 568                    false,
 569                    diagnostic_groups,
 570                    diagnostic_groups_truncated,
 571                    None,
 572                    debug_context.is_some(),
 573                    &worktree_snapshots,
 574                    index_state.as_deref(),
 575                    Some(options.max_prompt_bytes),
 576                    options.prompt_format,
 577                );
 578
 579                let debug_response_tx = if let Some((debug_tx, context)) = debug_context {
 580                    let (response_tx, response_rx) = oneshot::channel();
 581
 582                    let local_prompt = PlannedPrompt::populate(&request)
 583                        .and_then(|p| p.to_prompt_string().map(|p| p.0))
 584                        .map_err(|err| err.to_string());
 585
 586                    debug_tx
 587                        .unbounded_send(PredictionDebugInfo {
 588                            context,
 589                            retrieval_time,
 590                            buffer: buffer.downgrade(),
 591                            local_prompt,
 592                            position,
 593                            response_rx,
 594                        })
 595                        .ok();
 596                    Some(response_tx)
 597                } else {
 598                    None
 599                };
 600
 601                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 602                    if let Some(debug_response_tx) = debug_response_tx {
 603                        debug_response_tx
 604                            .send(Err("Request skipped".to_string()))
 605                            .ok();
 606                    }
 607                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 608                }
 609
 610                let response = Self::perform_request(client, llm_token, app_version, request).await;
 611
 612                if let Some(debug_response_tx) = debug_response_tx {
 613                    debug_response_tx
 614                        .send(response.as_ref().map_err(|err| err.to_string()).and_then(
 615                            |response| match some_or_debug_panic(response.0.debug_info.clone()) {
 616                                Some(debug_info) => Ok(debug_info),
 617                                None => Err("Missing debug info".to_string()),
 618                            },
 619                        ))
 620                        .ok();
 621                }
 622
 623                anyhow::Ok(Some(response?))
 624            }
 625        });
 626
 627        let buffer = buffer.clone();
 628
 629        cx.spawn({
 630            let project = project.clone();
 631            async move |this, cx| {
 632                match request_task.await {
 633                    Ok(Some((response, usage))) => {
 634                        if let Some(usage) = usage {
 635                            this.update(cx, |this, cx| {
 636                                this.user_store.update(cx, |user_store, cx| {
 637                                    user_store.update_edit_prediction_usage(usage, cx);
 638                                });
 639                            })
 640                            .ok();
 641                        }
 642
 643                        let prediction = EditPrediction::from_response(
 644                            response, &snapshot, &buffer, &project, cx,
 645                        )
 646                        .await;
 647
 648                        // TODO telemetry: duration, etc
 649                        Ok(prediction)
 650                    }
 651                    Ok(None) => Ok(None),
 652                    Err(err) => {
 653                        if err.is::<ZedUpdateRequiredError>() {
 654                            cx.update(|cx| {
 655                                this.update(cx, |this, _cx| {
 656                                    this.update_required = true;
 657                                })
 658                                .ok();
 659
 660                                let error_message: SharedString = err.to_string().into();
 661                                show_app_notification(
 662                                    NotificationId::unique::<ZedUpdateRequiredError>(),
 663                                    cx,
 664                                    move |cx| {
 665                                        cx.new(|cx| {
 666                                            ErrorMessagePrompt::new(error_message.clone(), cx)
 667                                                .with_link_button(
 668                                                    "Update Zed",
 669                                                    "https://zed.dev/releases",
 670                                                )
 671                                        })
 672                                    },
 673                                );
 674                            })
 675                            .ok();
 676                        }
 677
 678                        Err(err)
 679                    }
 680                }
 681            }
 682        })
 683    }
 684
 685    async fn perform_request(
 686        client: Arc<Client>,
 687        llm_token: LlmApiToken,
 688        app_version: SemanticVersion,
 689        request: predict_edits_v3::PredictEditsRequest,
 690    ) -> Result<(
 691        predict_edits_v3::PredictEditsResponse,
 692        Option<EditPredictionUsage>,
 693    )> {
 694        let http_client = client.http_client();
 695        let mut token = llm_token.acquire(&client).await?;
 696        let mut did_retry = false;
 697
 698        loop {
 699            let request_builder = http_client::Request::builder().method(Method::POST);
 700            let request_builder =
 701                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 702                    request_builder.uri(predict_edits_url)
 703                } else {
 704                    request_builder.uri(
 705                        http_client
 706                            .build_zed_llm_url("/predict_edits/v3", &[])?
 707                            .as_ref(),
 708                    )
 709                };
 710            let request = request_builder
 711                .header("Content-Type", "application/json")
 712                .header("Authorization", format!("Bearer {}", token))
 713                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 714                .body(serde_json::to_string(&request)?.into())?;
 715
 716            let mut response = http_client.send(request).await?;
 717
 718            if let Some(minimum_required_version) = response
 719                .headers()
 720                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 721                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 722            {
 723                anyhow::ensure!(
 724                    app_version >= minimum_required_version,
 725                    ZedUpdateRequiredError {
 726                        minimum_version: minimum_required_version
 727                    }
 728                );
 729            }
 730
 731            if response.status().is_success() {
 732                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 733
 734                let mut body = Vec::new();
 735                response.body_mut().read_to_end(&mut body).await?;
 736                return Ok((serde_json::from_slice(&body)?, usage));
 737            } else if !did_retry
 738                && response
 739                    .headers()
 740                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 741                    .is_some()
 742            {
 743                did_retry = true;
 744                token = llm_token.refresh(&client).await?;
 745            } else {
 746                let mut body = String::new();
 747                response.body_mut().read_to_string(&mut body).await?;
 748                anyhow::bail!(
 749                    "error predicting edits.\nStatus: {:?}\nBody: {}",
 750                    response.status(),
 751                    body
 752                );
 753            }
 754        }
 755    }
 756
 757    fn gather_nearby_diagnostics(
 758        cursor_offset: usize,
 759        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
 760        snapshot: &BufferSnapshot,
 761        max_diagnostics_bytes: usize,
 762    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
 763        // TODO: Could make this more efficient
 764        let mut diagnostic_groups = Vec::new();
 765        for (language_server_id, diagnostics) in diagnostic_sets {
 766            let mut groups = Vec::new();
 767            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
 768            diagnostic_groups.extend(
 769                groups
 770                    .into_iter()
 771                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
 772            );
 773        }
 774
 775        // sort by proximity to cursor
 776        diagnostic_groups.sort_by_key(|group| {
 777            let range = &group.entries[group.primary_ix].range;
 778            if range.start >= cursor_offset {
 779                range.start - cursor_offset
 780            } else if cursor_offset >= range.end {
 781                cursor_offset - range.end
 782            } else {
 783                (cursor_offset - range.start).min(range.end - cursor_offset)
 784            }
 785        });
 786
 787        let mut results = Vec::new();
 788        let mut diagnostic_groups_truncated = false;
 789        let mut diagnostics_byte_count = 0;
 790        for group in diagnostic_groups {
 791            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
 792            diagnostics_byte_count += raw_value.get().len();
 793            if diagnostics_byte_count > max_diagnostics_bytes {
 794                diagnostic_groups_truncated = true;
 795                break;
 796            }
 797            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
 798        }
 799
 800        (results, diagnostic_groups_truncated)
 801    }
 802
 803    // TODO: Dedupe with similar code in request_prediction?
 804    pub fn cloud_request_for_zeta_cli(
 805        &mut self,
 806        project: &Entity<Project>,
 807        buffer: &Entity<Buffer>,
 808        position: language::Anchor,
 809        cx: &mut Context<Self>,
 810    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 811        let project_state = self.projects.get(&project.entity_id());
 812
 813        let index_state = project_state.map(|state| {
 814            state
 815                .syntax_index
 816                .read_with(cx, |index, _cx| index.state().clone())
 817        });
 818        let options = self.options.clone();
 819        let snapshot = buffer.read(cx).snapshot();
 820        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 821            return Task::ready(Err(anyhow!("No file path for excerpt")));
 822        };
 823        let worktree_snapshots = project
 824            .read(cx)
 825            .worktrees(cx)
 826            .map(|worktree| worktree.read(cx).snapshot())
 827            .collect::<Vec<_>>();
 828
 829        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 830            let mut path = f.worktree.read(cx).absolutize(&f.path);
 831            if path.pop() { Some(path) } else { None }
 832        });
 833
 834        cx.background_spawn(async move {
 835            let index_state = if let Some(index_state) = index_state {
 836                Some(index_state.lock_owned().await)
 837            } else {
 838                None
 839            };
 840
 841            let cursor_point = position.to_point(&snapshot);
 842
 843            let debug_info = true;
 844            EditPredictionContext::gather_context(
 845                cursor_point,
 846                &snapshot,
 847                parent_abs_path.as_deref(),
 848                &options.context,
 849                index_state.as_deref(),
 850            )
 851            .context("Failed to select excerpt")
 852            .map(|context| {
 853                make_cloud_request(
 854                    excerpt_path.into(),
 855                    context,
 856                    // TODO pass everything
 857                    Vec::new(),
 858                    false,
 859                    Vec::new(),
 860                    false,
 861                    None,
 862                    debug_info,
 863                    &worktree_snapshots,
 864                    index_state.as_deref(),
 865                    Some(options.max_prompt_bytes),
 866                    options.prompt_format,
 867                )
 868            })
 869        })
 870    }
 871
 872    pub fn wait_for_initial_indexing(
 873        &mut self,
 874        project: &Entity<Project>,
 875        cx: &mut App,
 876    ) -> Task<Result<()>> {
 877        let zeta_project = self.get_or_init_zeta_project(project, cx);
 878        zeta_project
 879            .syntax_index
 880            .read(cx)
 881            .wait_for_initial_file_indexing(cx)
 882    }
 883}
 884
 885#[derive(Error, Debug)]
 886#[error(
 887    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
 888)]
 889pub struct ZedUpdateRequiredError {
 890    minimum_version: SemanticVersion,
 891}
 892
 893fn make_cloud_request(
 894    excerpt_path: Arc<Path>,
 895    context: EditPredictionContext,
 896    events: Vec<predict_edits_v3::Event>,
 897    can_collect_data: bool,
 898    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 899    diagnostic_groups_truncated: bool,
 900    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 901    debug_info: bool,
 902    worktrees: &Vec<worktree::Snapshot>,
 903    index_state: Option<&SyntaxIndexState>,
 904    prompt_max_bytes: Option<usize>,
 905    prompt_format: PromptFormat,
 906) -> predict_edits_v3::PredictEditsRequest {
 907    let mut signatures = Vec::new();
 908    let mut declaration_to_signature_index = HashMap::default();
 909    let mut referenced_declarations = Vec::new();
 910
 911    for snippet in context.declarations {
 912        let project_entry_id = snippet.declaration.project_entry_id();
 913        let Some(path) = worktrees.iter().find_map(|worktree| {
 914            worktree.entry_for_id(project_entry_id).map(|entry| {
 915                let mut full_path = RelPathBuf::new();
 916                full_path.push(worktree.root_name());
 917                full_path.push(&entry.path);
 918                full_path
 919            })
 920        }) else {
 921            continue;
 922        };
 923
 924        let parent_index = index_state.and_then(|index_state| {
 925            snippet.declaration.parent().and_then(|parent| {
 926                add_signature(
 927                    parent,
 928                    &mut declaration_to_signature_index,
 929                    &mut signatures,
 930                    index_state,
 931                )
 932            })
 933        });
 934
 935        let (text, text_is_truncated) = snippet.declaration.item_text();
 936        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
 937            path: path.as_std_path().into(),
 938            text: text.into(),
 939            range: snippet.declaration.item_line_range(),
 940            text_is_truncated,
 941            signature_range: snippet.declaration.signature_range_in_item_text(),
 942            parent_index,
 943            signature_score: snippet.score(DeclarationStyle::Signature),
 944            declaration_score: snippet.score(DeclarationStyle::Declaration),
 945            score_components: snippet.components,
 946        });
 947    }
 948
 949    let excerpt_parent = index_state.and_then(|index_state| {
 950        context
 951            .excerpt
 952            .parent_declarations
 953            .last()
 954            .and_then(|(parent, _)| {
 955                add_signature(
 956                    *parent,
 957                    &mut declaration_to_signature_index,
 958                    &mut signatures,
 959                    index_state,
 960                )
 961            })
 962    });
 963
 964    predict_edits_v3::PredictEditsRequest {
 965        excerpt_path,
 966        excerpt: context.excerpt_text.body,
 967        excerpt_line_range: context.excerpt.line_range,
 968        excerpt_range: context.excerpt.range,
 969        cursor_point: predict_edits_v3::Point {
 970            line: predict_edits_v3::Line(context.cursor_point.row),
 971            column: context.cursor_point.column,
 972        },
 973        referenced_declarations,
 974        signatures,
 975        excerpt_parent,
 976        events,
 977        can_collect_data,
 978        diagnostic_groups,
 979        diagnostic_groups_truncated,
 980        git_info,
 981        debug_info,
 982        prompt_max_bytes,
 983        prompt_format,
 984    }
 985}
 986
 987fn add_signature(
 988    declaration_id: DeclarationId,
 989    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
 990    signatures: &mut Vec<Signature>,
 991    index: &SyntaxIndexState,
 992) -> Option<usize> {
 993    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
 994        return Some(*signature_index);
 995    }
 996    let Some(parent_declaration) = index.declaration(declaration_id) else {
 997        log::error!("bug: missing parent declaration");
 998        return None;
 999    };
1000    let parent_index = parent_declaration.parent().and_then(|parent| {
1001        add_signature(parent, declaration_to_signature_index, signatures, index)
1002    });
1003    let (text, text_is_truncated) = parent_declaration.signature_text();
1004    let signature_index = signatures.len();
1005    signatures.push(Signature {
1006        text: text.into(),
1007        text_is_truncated,
1008        parent_index,
1009        range: parent_declaration.signature_line_range(),
1010    });
1011    declaration_to_signature_index.insert(declaration_id, signature_index);
1012    Some(signature_index)
1013}
1014
#[cfg(test)]
mod tests {
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3::{self, Point};
    use edit_prediction_context::Line;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    /// Opens the buffer at `path` (a worktree-rooted path such as `root/foo.md`) in `project`.
    ///
    /// Panics if the path does not resolve to a project path or the buffer fails to open.
    async fn open_project_buffer(
        project: &Entity<Project>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Entity<language::Buffer> {
        project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path(path, cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap()
    }

    /// Builds a response holding a single edit for `path` that spans lines `0` through the last
    /// line of `snapshot` (inclusive) and replaces them with `content`.
    fn whole_buffer_edit_response(
        path: &str,
        snapshot: &language::BufferSnapshot,
        content: &'static str,
    ) -> predict_edits_v3::PredictEditsResponse {
        predict_edits_v3::PredictEditsResponse {
            request_id: Uuid::new_v4(),
            edits: vec![predict_edits_v3::Edit {
                path: Path::new(path).into(),
                range: Line(0)..Line(snapshot.max_point().row + 1),
                content: content.into(),
            }],
            debug_info: None,
        }
    }

    /// The current prediction is reported as `Local` for the buffer it edits and as `Jump` when
    /// queried from a different buffer.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = open_project_buffer(&project, path!("root/1.txt"), cx).await;
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(whole_buffer_edit_response(
                path!("root/1.txt"),
                &snapshot1,
                "Hello!\nHow are you?\nBye",
            ))
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // The edit range is derived from `snapshot1`; both fixture files have three lines.
        respond_tx
            .send(whole_buffer_edit_response(
                path!("root/2.txt"),
                &snapshot1,
                "Hola!\nComo estas?\nAdios",
            ))
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = open_project_buffer(&project, path!("root/2.txt"), cx).await;

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    /// A request carries the excerpt path and cursor point, and the response is turned into a
    /// minimal insertion at the cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = open_project_buffer(&project, path!("root/foo.md"), cx).await;
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(whole_buffer_edit_response(
                path!("root/foo.md"),
                &snapshot,
                "Hello!\nHow are you?\nBye",
            ))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// An edit made to a registered buffer is reported as a `BufferChange` event in the next
    /// request.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = open_project_buffer(&project, path!("root/foo.md"), cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Fill in the empty second line.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                        @@ -1,3 +1,3 @@
                         Hello!
                        -
                        +How
                         Bye
                    "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(whole_buffer_edit_response(
                path!("root/foo.md"),
                &snapshot,
                "Hello!\nHow are you?\nBye",
            ))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Diagnostics published for the buffer's file are serialized into the request's
    /// diagnostic groups.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = open_project_buffer(&project, path!("root/foo.md"), cx).await;

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        // The task and the response sender are kept alive but never completed; only the
        // outgoing request is inspected.
        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    /// Sets up settings, a fake HTTP client, and the global `Zeta` entity.
    ///
    /// The fake client answers `/client/llm_tokens` with a fixed token. Every request to
    /// `/predict_edits/v3` is parsed and forwarded on the returned channel together with a
    /// oneshot sender; the HTTP response is whatever the test sends through that sender. Any
    /// other path panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                // Fail loudly here rather than letting a truncated body
                                // surface as a JSON parse error below.
                                body.read_to_end(&mut buf)
                                    .await
                                    .expect("failed to read prediction request body");
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}