zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
   7};
   8use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
   9use edit_prediction_context::{
  10    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  11    EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
  12};
  13use futures::AsyncReadExt as _;
  14use futures::channel::{mpsc, oneshot};
  15use gpui::http_client::Method;
  16use gpui::{
  17    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  18    http_client, prelude::*,
  19};
  20use language::BufferSnapshot;
  21use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  22use language_model::{LlmApiToken, RefreshLlmTokenListener};
  23use project::Project;
  24use release_channel::AppVersion;
  25use std::collections::{HashMap, VecDeque, hash_map};
  26use std::path::Path;
  27use std::str::FromStr as _;
  28use std::sync::Arc;
  29use std::time::{Duration, Instant};
  30use thiserror::Error;
  31use util::rel_path::RelPathBuf;
  32use util::some_or_debug_panic;
  33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  34
  35mod prediction;
  36mod provider;
  37
  38use crate::prediction::EditPrediction;
  39pub use provider::ZetaEditPredictionProvider;
  40
  41const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  42
  43/// Maximum number of events to track.
  44const MAX_EVENT_COUNT: usize = 16;
  45
  46pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
  47    use_imports: true,
  48    excerpt: EditPredictionExcerptOptions {
  49        max_bytes: 512,
  50        min_bytes: 128,
  51        target_before_cursor_over_total_bytes: 0.5,
  52    },
  53    score: EditPredictionScoreOptions {
  54        omit_excerpt_overlaps: true,
  55    },
  56};
  57
  58pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
  59    context: DEFAULT_CONTEXT_OPTIONS,
  60    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
  61    max_diagnostic_bytes: 2048,
  62    prompt_format: PromptFormat::DEFAULT,
  63    file_indexing_parallelism: 1,
  64};
  65
/// App-global wrapper around the shared [`Zeta`] entity; installed with
/// `cx.set_global` in [`Zeta::global`] and read back by [`Zeta::try_global`].
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
  70
/// Edit-prediction engine: records per-project buffer edits, gathers context
/// around the cursor, and requests predictions from the prediction endpoint.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token for LLM requests; refreshed whenever `RefreshLlmTokenListener`
    /// emits an event (see `Zeta::new`).
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set when a request fails with `ZedUpdateRequiredError`.
    update_required: bool,
    /// Installed by `Zeta::debug_info`; when present, every request also sends a
    /// `PredictionDebugInfo` on this channel.
    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
}
  81
/// Tunable settings for context gathering and prompt construction.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Passed to `EditPredictionContext::gather_context`.
    pub context: EditPredictionContextOptions,
    /// Prompt size limit forwarded to `make_cloud_request`.
    pub max_prompt_bytes: usize,
    /// Byte budget for serialized diagnostic groups; groups beyond it are dropped.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Passed to `SyntaxIndex::new` when a project is first registered.
    pub file_indexing_parallelism: usize,
}
  90
/// Debug record for one prediction request, sent on the channel returned by
/// `Zeta::debug_info`.
pub struct PredictionDebugInfo {
    /// Context gathered around the cursor for this request.
    pub context: EditPredictionContext,
    /// Time spent in `EditPredictionContext::gather_context`.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// Prompt rendered locally from the request, or the rendering error message.
    pub local_prompt: Result<String, String>,
    /// Receives the server's debug info, or an error message when the request
    /// failed, was skipped, or the response carried no debug info.
    pub response_rx: oneshot::Receiver<Result<RequestDebugInfo, String>>,
}

/// Debug information carried on a prediction response (`response.debug_info`).
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 101
/// Prediction state tracked for a single project.
struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first, never more than `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being observed, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
}

/// The prediction currently offered for a project, plus the id of the buffer whose
/// request produced it (which may differ from the buffer the prediction edits).
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
 114
 115impl CurrentEditPrediction {
 116    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 117        let Some(new_edits) = self
 118            .prediction
 119            .interpolate(&self.prediction.buffer.read(cx))
 120        else {
 121            return false;
 122        };
 123
 124        if self.prediction.buffer != old_prediction.prediction.buffer {
 125            return true;
 126        }
 127
 128        let Some(old_edits) = old_prediction
 129            .prediction
 130            .interpolate(&old_prediction.prediction.buffer.read(cx))
 131        else {
 132            return true;
 133        };
 134
 135        // This reduces the occurrence of UI thrash from replacing edits
 136        //
 137        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 138        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 139            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 140            && old_edits.len() == 1
 141            && new_edits.len() == 1
 142        {
 143            let (old_range, old_text) = &old_edits[0];
 144            let (new_range, new_text) = &new_edits[0];
 145            new_range == old_range && new_text.starts_with(old_text)
 146        } else {
 147            true
 148        }
 149    }
 150}
 151
/// A prediction from the perspective of a buffer.
///
/// `Local`: the prediction targets this buffer. `Jump`: it targets another buffer
/// but was requested from this one (see `Zeta::current_prediction_for_buffer`).
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}

/// A buffer whose edits are being tracked for a project.
struct RegisteredBuffer {
    /// Snapshot as of the last recorded change; `report_changes_for_buffer`
    /// compares its version against the buffer's current snapshot.
    snapshot: BufferSnapshot,
    /// The edit subscription and release observer created in `register_buffer_impl`.
    _subscriptions: [gpui::Subscription; 2],
}

/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// One buffer went from `old_snapshot` to `new_snapshot`; back-to-back changes
    /// to the same buffer within `BUFFER_CHANGE_GROUPING_INTERVAL` are merged by
    /// `Zeta::push_event`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
 172
 173impl Zeta {
 174    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 175        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 176    }
 177
 178    pub fn global(
 179        client: &Arc<Client>,
 180        user_store: &Entity<UserStore>,
 181        cx: &mut App,
 182    ) -> Entity<Self> {
 183        cx.try_global::<ZetaGlobal>()
 184            .map(|global| global.0.clone())
 185            .unwrap_or_else(|| {
 186                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 187                cx.set_global(ZetaGlobal(zeta.clone()));
 188                zeta
 189            })
 190    }
 191
 192    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 193        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 194
 195        Self {
 196            projects: HashMap::new(),
 197            client,
 198            user_store,
 199            options: DEFAULT_OPTIONS,
 200            llm_token: LlmApiToken::default(),
 201            _llm_token_subscription: cx.subscribe(
 202                &refresh_llm_token_listener,
 203                |this, _listener, _event, cx| {
 204                    let client = this.client.clone();
 205                    let llm_token = this.llm_token.clone();
 206                    cx.spawn(async move |_this, _cx| {
 207                        llm_token.refresh(&client).await?;
 208                        anyhow::Ok(())
 209                    })
 210                    .detach_and_log_err(cx);
 211                },
 212            ),
 213            update_required: false,
 214            debug_tx: None,
 215        }
 216    }
 217
 218    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
 219        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 220        self.debug_tx = Some(debug_watch_tx);
 221        debug_watch_rx
 222    }
 223
 224    pub fn options(&self) -> &ZetaOptions {
 225        &self.options
 226    }
 227
 228    pub fn set_options(&mut self, options: ZetaOptions) {
 229        self.options = options;
 230    }
 231
 232    pub fn clear_history(&mut self) {
 233        for zeta_project in self.projects.values_mut() {
 234            zeta_project.events.clear();
 235        }
 236    }
 237
 238    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 239        self.user_store.read(cx).edit_prediction_usage()
 240    }
 241
 242    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 243        self.get_or_init_zeta_project(project, cx);
 244    }
 245
 246    pub fn register_buffer(
 247        &mut self,
 248        buffer: &Entity<Buffer>,
 249        project: &Entity<Project>,
 250        cx: &mut Context<Self>,
 251    ) {
 252        let zeta_project = self.get_or_init_zeta_project(project, cx);
 253        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 254    }
 255
 256    fn get_or_init_zeta_project(
 257        &mut self,
 258        project: &Entity<Project>,
 259        cx: &mut App,
 260    ) -> &mut ZetaProject {
 261        self.projects
 262            .entry(project.entity_id())
 263            .or_insert_with(|| ZetaProject {
 264                syntax_index: cx.new(|cx| {
 265                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 266                }),
 267                events: VecDeque::new(),
 268                registered_buffers: HashMap::new(),
 269                current_prediction: None,
 270            })
 271    }
 272
 273    fn register_buffer_impl<'a>(
 274        zeta_project: &'a mut ZetaProject,
 275        buffer: &Entity<Buffer>,
 276        project: &Entity<Project>,
 277        cx: &mut Context<Self>,
 278    ) -> &'a mut RegisteredBuffer {
 279        let buffer_id = buffer.entity_id();
 280        match zeta_project.registered_buffers.entry(buffer_id) {
 281            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 282            hash_map::Entry::Vacant(entry) => {
 283                let snapshot = buffer.read(cx).snapshot();
 284                let project_entity_id = project.entity_id();
 285                entry.insert(RegisteredBuffer {
 286                    snapshot,
 287                    _subscriptions: [
 288                        cx.subscribe(buffer, {
 289                            let project = project.downgrade();
 290                            move |this, buffer, event, cx| {
 291                                if let language::BufferEvent::Edited = event
 292                                    && let Some(project) = project.upgrade()
 293                                {
 294                                    this.report_changes_for_buffer(&buffer, &project, cx);
 295                                }
 296                            }
 297                        }),
 298                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 299                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 300                            else {
 301                                return;
 302                            };
 303                            zeta_project.registered_buffers.remove(&buffer_id);
 304                        }),
 305                    ],
 306                })
 307            }
 308        }
 309    }
 310
 311    fn report_changes_for_buffer(
 312        &mut self,
 313        buffer: &Entity<Buffer>,
 314        project: &Entity<Project>,
 315        cx: &mut Context<Self>,
 316    ) -> BufferSnapshot {
 317        let zeta_project = self.get_or_init_zeta_project(project, cx);
 318        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 319
 320        let new_snapshot = buffer.read(cx).snapshot();
 321        if new_snapshot.version != registered_buffer.snapshot.version {
 322            let old_snapshot =
 323                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 324            Self::push_event(
 325                zeta_project,
 326                Event::BufferChange {
 327                    old_snapshot,
 328                    new_snapshot: new_snapshot.clone(),
 329                    timestamp: Instant::now(),
 330                },
 331            );
 332        }
 333
 334        new_snapshot
 335    }
 336
 337    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 338        let events = &mut zeta_project.events;
 339
 340        if let Some(Event::BufferChange {
 341            new_snapshot: last_new_snapshot,
 342            timestamp: last_timestamp,
 343            ..
 344        }) = events.back_mut()
 345        {
 346            // Coalesce edits for the same buffer when they happen one after the other.
 347            let Event::BufferChange {
 348                old_snapshot,
 349                new_snapshot,
 350                timestamp,
 351            } = &event;
 352
 353            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 354                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 355                && old_snapshot.version == last_new_snapshot.version
 356            {
 357                *last_new_snapshot = new_snapshot.clone();
 358                *last_timestamp = *timestamp;
 359                return;
 360            }
 361        }
 362
 363        if events.len() >= MAX_EVENT_COUNT {
 364            // These are halved instead of popping to improve prompt caching.
 365            events.drain(..MAX_EVENT_COUNT / 2);
 366        }
 367
 368        events.push_back(event);
 369    }
 370
 371    fn current_prediction_for_buffer(
 372        &self,
 373        buffer: &Entity<Buffer>,
 374        project: &Entity<Project>,
 375        cx: &App,
 376    ) -> Option<BufferEditPrediction<'_>> {
 377        let project_state = self.projects.get(&project.entity_id())?;
 378
 379        let CurrentEditPrediction {
 380            requested_by_buffer_id,
 381            prediction,
 382        } = project_state.current_prediction.as_ref()?;
 383
 384        if prediction.targets_buffer(buffer.read(cx), cx) {
 385            Some(BufferEditPrediction::Local { prediction })
 386        } else if *requested_by_buffer_id == buffer.entity_id() {
 387            Some(BufferEditPrediction::Jump { prediction })
 388        } else {
 389            None
 390        }
 391    }
 392
 393    fn accept_current_prediction(&mut self, project: &Entity<Project>) {
 394        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 395            project_state.current_prediction.take();
 396        };
 397        // TODO report accepted
 398    }
 399
 400    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 401        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 402            project_state.current_prediction.take();
 403        };
 404    }
 405
 406    pub fn refresh_prediction(
 407        &mut self,
 408        project: &Entity<Project>,
 409        buffer: &Entity<Buffer>,
 410        position: language::Anchor,
 411        cx: &mut Context<Self>,
 412    ) -> Task<Result<()>> {
 413        let request_task = self.request_prediction(project, buffer, position, cx);
 414        let buffer = buffer.clone();
 415        let project = project.clone();
 416
 417        cx.spawn(async move |this, cx| {
 418            if let Some(prediction) = request_task.await? {
 419                this.update(cx, |this, cx| {
 420                    let project_state = this
 421                        .projects
 422                        .get_mut(&project.entity_id())
 423                        .context("Project not found")?;
 424
 425                    let new_prediction = CurrentEditPrediction {
 426                        requested_by_buffer_id: buffer.entity_id(),
 427                        prediction: prediction,
 428                    };
 429
 430                    if project_state
 431                        .current_prediction
 432                        .as_ref()
 433                        .is_none_or(|old_prediction| {
 434                            new_prediction.should_replace_prediction(&old_prediction, cx)
 435                        })
 436                    {
 437                        project_state.current_prediction = Some(new_prediction);
 438                    }
 439                    anyhow::Ok(())
 440                })??;
 441            }
 442            Ok(())
 443        })
 444    }
 445
 446    fn request_prediction(
 447        &mut self,
 448        project: &Entity<Project>,
 449        buffer: &Entity<Buffer>,
 450        position: language::Anchor,
 451        cx: &mut Context<Self>,
 452    ) -> Task<Result<Option<EditPrediction>>> {
 453        let project_state = self.projects.get(&project.entity_id());
 454
 455        let index_state = project_state.map(|state| {
 456            state
 457                .syntax_index
 458                .read_with(cx, |index, _cx| index.state().clone())
 459        });
 460        let options = self.options.clone();
 461        let snapshot = buffer.read(cx).snapshot();
 462        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
 463            return Task::ready(Err(anyhow!("No file path for excerpt")));
 464        };
 465        let client = self.client.clone();
 466        let llm_token = self.llm_token.clone();
 467        let app_version = AppVersion::global(cx);
 468        let worktree_snapshots = project
 469            .read(cx)
 470            .worktrees(cx)
 471            .map(|worktree| worktree.read(cx).snapshot())
 472            .collect::<Vec<_>>();
 473        let debug_tx = self.debug_tx.clone();
 474
 475        let events = project_state
 476            .map(|state| {
 477                state
 478                    .events
 479                    .iter()
 480                    .filter_map(|event| match event {
 481                        Event::BufferChange {
 482                            old_snapshot,
 483                            new_snapshot,
 484                            ..
 485                        } => {
 486                            let path = new_snapshot.file().map(|f| f.full_path(cx));
 487
 488                            let old_path = old_snapshot.file().and_then(|f| {
 489                                let old_path = f.full_path(cx);
 490                                if Some(&old_path) != path.as_ref() {
 491                                    Some(old_path)
 492                                } else {
 493                                    None
 494                                }
 495                            });
 496
 497                            // TODO [zeta2] move to bg?
 498                            let diff =
 499                                language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 500
 501                            if path == old_path && diff.is_empty() {
 502                                None
 503                            } else {
 504                                Some(predict_edits_v3::Event::BufferChange {
 505                                    old_path,
 506                                    path,
 507                                    diff,
 508                                    //todo: Actually detect if this edit was predicted or not
 509                                    predicted: false,
 510                                })
 511                            }
 512                        }
 513                    })
 514                    .collect::<Vec<_>>()
 515            })
 516            .unwrap_or_default();
 517
 518        let diagnostics = snapshot.diagnostic_sets().clone();
 519
 520        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 521            let mut path = f.worktree.read(cx).absolutize(&f.path);
 522            if path.pop() { Some(path) } else { None }
 523        });
 524
 525        let request_task = cx.background_spawn({
 526            let snapshot = snapshot.clone();
 527            let buffer = buffer.clone();
 528            async move {
 529                let index_state = if let Some(index_state) = index_state {
 530                    Some(index_state.lock_owned().await)
 531                } else {
 532                    None
 533                };
 534
 535                let cursor_offset = position.to_offset(&snapshot);
 536                let cursor_point = cursor_offset.to_point(&snapshot);
 537
 538                let before_retrieval = chrono::Utc::now();
 539
 540                let Some(context) = EditPredictionContext::gather_context(
 541                    cursor_point,
 542                    &snapshot,
 543                    parent_abs_path.as_deref(),
 544                    &options.context,
 545                    index_state.as_deref(),
 546                ) else {
 547                    return Ok(None);
 548                };
 549
 550                let retrieval_time = chrono::Utc::now() - before_retrieval;
 551
 552                let (diagnostic_groups, diagnostic_groups_truncated) =
 553                    Self::gather_nearby_diagnostics(
 554                        cursor_offset,
 555                        &diagnostics,
 556                        &snapshot,
 557                        options.max_diagnostic_bytes,
 558                    );
 559
 560                let debug_context = debug_tx.map(|tx| (tx, context.clone()));
 561
 562                let request = make_cloud_request(
 563                    excerpt_path,
 564                    context,
 565                    events,
 566                    // TODO data collection
 567                    false,
 568                    diagnostic_groups,
 569                    diagnostic_groups_truncated,
 570                    None,
 571                    debug_context.is_some(),
 572                    &worktree_snapshots,
 573                    index_state.as_deref(),
 574                    Some(options.max_prompt_bytes),
 575                    options.prompt_format,
 576                );
 577
 578                let debug_response_tx = if let Some((debug_tx, context)) = debug_context {
 579                    let (response_tx, response_rx) = oneshot::channel();
 580
 581                    let local_prompt = PlannedPrompt::populate(&request)
 582                        .and_then(|p| p.to_prompt_string().map(|p| p.0))
 583                        .map_err(|err| err.to_string());
 584
 585                    debug_tx
 586                        .unbounded_send(PredictionDebugInfo {
 587                            context,
 588                            retrieval_time,
 589                            buffer: buffer.downgrade(),
 590                            local_prompt,
 591                            position,
 592                            response_rx,
 593                        })
 594                        .ok();
 595                    Some(response_tx)
 596                } else {
 597                    None
 598                };
 599
 600                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 601                    if let Some(debug_response_tx) = debug_response_tx {
 602                        debug_response_tx
 603                            .send(Err("Request skipped".to_string()))
 604                            .ok();
 605                    }
 606                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 607                }
 608
 609                let response = Self::perform_request(client, llm_token, app_version, request).await;
 610
 611                if let Some(debug_response_tx) = debug_response_tx {
 612                    debug_response_tx
 613                        .send(response.as_ref().map_err(|err| err.to_string()).and_then(
 614                            |response| match some_or_debug_panic(response.0.debug_info.clone()) {
 615                                Some(debug_info) => Ok(debug_info),
 616                                None => Err("Missing debug info".to_string()),
 617                            },
 618                        ))
 619                        .ok();
 620                }
 621
 622                anyhow::Ok(Some(response?))
 623            }
 624        });
 625
 626        let buffer = buffer.clone();
 627
 628        cx.spawn({
 629            let project = project.clone();
 630            async move |this, cx| {
 631                match request_task.await {
 632                    Ok(Some((response, usage))) => {
 633                        if let Some(usage) = usage {
 634                            this.update(cx, |this, cx| {
 635                                this.user_store.update(cx, |user_store, cx| {
 636                                    user_store.update_edit_prediction_usage(usage, cx);
 637                                });
 638                            })
 639                            .ok();
 640                        }
 641
 642                        let prediction = EditPrediction::from_response(
 643                            response, &snapshot, &buffer, &project, cx,
 644                        )
 645                        .await;
 646
 647                        // TODO telemetry: duration, etc
 648                        Ok(prediction)
 649                    }
 650                    Ok(None) => Ok(None),
 651                    Err(err) => {
 652                        if err.is::<ZedUpdateRequiredError>() {
 653                            cx.update(|cx| {
 654                                this.update(cx, |this, _cx| {
 655                                    this.update_required = true;
 656                                })
 657                                .ok();
 658
 659                                let error_message: SharedString = err.to_string().into();
 660                                show_app_notification(
 661                                    NotificationId::unique::<ZedUpdateRequiredError>(),
 662                                    cx,
 663                                    move |cx| {
 664                                        cx.new(|cx| {
 665                                            ErrorMessagePrompt::new(error_message.clone(), cx)
 666                                                .with_link_button(
 667                                                    "Update Zed",
 668                                                    "https://zed.dev/releases",
 669                                                )
 670                                        })
 671                                    },
 672                                );
 673                            })
 674                            .ok();
 675                        }
 676
 677                        Err(err)
 678                    }
 679                }
 680            }
 681        })
 682    }
 683
 684    async fn perform_request(
 685        client: Arc<Client>,
 686        llm_token: LlmApiToken,
 687        app_version: SemanticVersion,
 688        request: predict_edits_v3::PredictEditsRequest,
 689    ) -> Result<(
 690        predict_edits_v3::PredictEditsResponse,
 691        Option<EditPredictionUsage>,
 692    )> {
 693        let http_client = client.http_client();
 694        let mut token = llm_token.acquire(&client).await?;
 695        let mut did_retry = false;
 696
 697        loop {
 698            let request_builder = http_client::Request::builder().method(Method::POST);
 699            let request_builder =
 700                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 701                    request_builder.uri(predict_edits_url)
 702                } else {
 703                    request_builder.uri(
 704                        http_client
 705                            .build_zed_llm_url("/predict_edits/v3", &[])?
 706                            .as_ref(),
 707                    )
 708                };
 709            let request = request_builder
 710                .header("Content-Type", "application/json")
 711                .header("Authorization", format!("Bearer {}", token))
 712                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 713                .body(serde_json::to_string(&request)?.into())?;
 714
 715            let mut response = http_client.send(request).await?;
 716
 717            if let Some(minimum_required_version) = response
 718                .headers()
 719                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 720                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 721            {
 722                anyhow::ensure!(
 723                    app_version >= minimum_required_version,
 724                    ZedUpdateRequiredError {
 725                        minimum_version: minimum_required_version
 726                    }
 727                );
 728            }
 729
 730            if response.status().is_success() {
 731                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 732
 733                let mut body = Vec::new();
 734                response.body_mut().read_to_end(&mut body).await?;
 735                return Ok((serde_json::from_slice(&body)?, usage));
 736            } else if !did_retry
 737                && response
 738                    .headers()
 739                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 740                    .is_some()
 741            {
 742                did_retry = true;
 743                token = llm_token.refresh(&client).await?;
 744            } else {
 745                let mut body = String::new();
 746                response.body_mut().read_to_string(&mut body).await?;
 747                anyhow::bail!(
 748                    "error predicting edits.\nStatus: {:?}\nBody: {}",
 749                    response.status(),
 750                    body
 751                );
 752            }
 753        }
 754    }
 755
 756    fn gather_nearby_diagnostics(
 757        cursor_offset: usize,
 758        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
 759        snapshot: &BufferSnapshot,
 760        max_diagnostics_bytes: usize,
 761    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
 762        // TODO: Could make this more efficient
 763        let mut diagnostic_groups = Vec::new();
 764        for (language_server_id, diagnostics) in diagnostic_sets {
 765            let mut groups = Vec::new();
 766            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
 767            diagnostic_groups.extend(
 768                groups
 769                    .into_iter()
 770                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
 771            );
 772        }
 773
 774        // sort by proximity to cursor
 775        diagnostic_groups.sort_by_key(|group| {
 776            let range = &group.entries[group.primary_ix].range;
 777            if range.start >= cursor_offset {
 778                range.start - cursor_offset
 779            } else if cursor_offset >= range.end {
 780                cursor_offset - range.end
 781            } else {
 782                (cursor_offset - range.start).min(range.end - cursor_offset)
 783            }
 784        });
 785
 786        let mut results = Vec::new();
 787        let mut diagnostic_groups_truncated = false;
 788        let mut diagnostics_byte_count = 0;
 789        for group in diagnostic_groups {
 790            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
 791            diagnostics_byte_count += raw_value.get().len();
 792            if diagnostics_byte_count > max_diagnostics_bytes {
 793                diagnostic_groups_truncated = true;
 794                break;
 795            }
 796            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
 797        }
 798
 799        (results, diagnostic_groups_truncated)
 800    }
 801
 802    // TODO: Dedupe with similar code in request_prediction?
 803    pub fn cloud_request_for_zeta_cli(
 804        &mut self,
 805        project: &Entity<Project>,
 806        buffer: &Entity<Buffer>,
 807        position: language::Anchor,
 808        cx: &mut Context<Self>,
 809    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 810        let project_state = self.projects.get(&project.entity_id());
 811
 812        let index_state = project_state.map(|state| {
 813            state
 814                .syntax_index
 815                .read_with(cx, |index, _cx| index.state().clone())
 816        });
 817        let options = self.options.clone();
 818        let snapshot = buffer.read(cx).snapshot();
 819        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 820            return Task::ready(Err(anyhow!("No file path for excerpt")));
 821        };
 822        let worktree_snapshots = project
 823            .read(cx)
 824            .worktrees(cx)
 825            .map(|worktree| worktree.read(cx).snapshot())
 826            .collect::<Vec<_>>();
 827
 828        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 829            let mut path = f.worktree.read(cx).absolutize(&f.path);
 830            if path.pop() { Some(path) } else { None }
 831        });
 832
 833        cx.background_spawn(async move {
 834            let index_state = if let Some(index_state) = index_state {
 835                Some(index_state.lock_owned().await)
 836            } else {
 837                None
 838            };
 839
 840            let cursor_point = position.to_point(&snapshot);
 841
 842            let debug_info = true;
 843            EditPredictionContext::gather_context(
 844                cursor_point,
 845                &snapshot,
 846                parent_abs_path.as_deref(),
 847                &options.context,
 848                index_state.as_deref(),
 849            )
 850            .context("Failed to select excerpt")
 851            .map(|context| {
 852                make_cloud_request(
 853                    excerpt_path.into(),
 854                    context,
 855                    // TODO pass everything
 856                    Vec::new(),
 857                    false,
 858                    Vec::new(),
 859                    false,
 860                    None,
 861                    debug_info,
 862                    &worktree_snapshots,
 863                    index_state.as_deref(),
 864                    Some(options.max_prompt_bytes),
 865                    options.prompt_format,
 866                )
 867            })
 868        })
 869    }
 870
 871    pub fn wait_for_initial_indexing(
 872        &mut self,
 873        project: &Entity<Project>,
 874        cx: &mut App,
 875    ) -> Task<Result<()>> {
 876        let zeta_project = self.get_or_init_zeta_project(project, cx);
 877        zeta_project
 878            .syntax_index
 879            .read(cx)
 880            .wait_for_initial_file_indexing(cx)
 881    }
 882}
 883
 884#[derive(Error, Debug)]
 885#[error(
 886    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
 887)]
 888pub struct ZedUpdateRequiredError {
 889    minimum_version: SemanticVersion,
 890}
 891
 892fn make_cloud_request(
 893    excerpt_path: Arc<Path>,
 894    context: EditPredictionContext,
 895    events: Vec<predict_edits_v3::Event>,
 896    can_collect_data: bool,
 897    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 898    diagnostic_groups_truncated: bool,
 899    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 900    debug_info: bool,
 901    worktrees: &Vec<worktree::Snapshot>,
 902    index_state: Option<&SyntaxIndexState>,
 903    prompt_max_bytes: Option<usize>,
 904    prompt_format: PromptFormat,
 905) -> predict_edits_v3::PredictEditsRequest {
 906    let mut signatures = Vec::new();
 907    let mut declaration_to_signature_index = HashMap::default();
 908    let mut referenced_declarations = Vec::new();
 909
 910    for snippet in context.declarations {
 911        let project_entry_id = snippet.declaration.project_entry_id();
 912        let Some(path) = worktrees.iter().find_map(|worktree| {
 913            worktree.entry_for_id(project_entry_id).map(|entry| {
 914                let mut full_path = RelPathBuf::new();
 915                full_path.push(worktree.root_name());
 916                full_path.push(&entry.path);
 917                full_path
 918            })
 919        }) else {
 920            continue;
 921        };
 922
 923        let parent_index = index_state.and_then(|index_state| {
 924            snippet.declaration.parent().and_then(|parent| {
 925                add_signature(
 926                    parent,
 927                    &mut declaration_to_signature_index,
 928                    &mut signatures,
 929                    index_state,
 930                )
 931            })
 932        });
 933
 934        let (text, text_is_truncated) = snippet.declaration.item_text();
 935        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
 936            path: path.as_std_path().into(),
 937            text: text.into(),
 938            range: snippet.declaration.item_line_range(),
 939            text_is_truncated,
 940            signature_range: snippet.declaration.signature_range_in_item_text(),
 941            parent_index,
 942            signature_score: snippet.score(DeclarationStyle::Signature),
 943            declaration_score: snippet.score(DeclarationStyle::Declaration),
 944            score_components: snippet.components,
 945        });
 946    }
 947
 948    let excerpt_parent = index_state.and_then(|index_state| {
 949        context
 950            .excerpt
 951            .parent_declarations
 952            .last()
 953            .and_then(|(parent, _)| {
 954                add_signature(
 955                    *parent,
 956                    &mut declaration_to_signature_index,
 957                    &mut signatures,
 958                    index_state,
 959                )
 960            })
 961    });
 962
 963    predict_edits_v3::PredictEditsRequest {
 964        excerpt_path,
 965        excerpt: context.excerpt_text.body,
 966        excerpt_line_range: context.excerpt.line_range,
 967        excerpt_range: context.excerpt.range,
 968        cursor_point: predict_edits_v3::Point {
 969            line: predict_edits_v3::Line(context.cursor_point.row),
 970            column: context.cursor_point.column,
 971        },
 972        referenced_declarations,
 973        signatures,
 974        excerpt_parent,
 975        events,
 976        can_collect_data,
 977        diagnostic_groups,
 978        diagnostic_groups_truncated,
 979        git_info,
 980        debug_info,
 981        prompt_max_bytes,
 982        prompt_format,
 983    }
 984}
 985
 986fn add_signature(
 987    declaration_id: DeclarationId,
 988    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
 989    signatures: &mut Vec<Signature>,
 990    index: &SyntaxIndexState,
 991) -> Option<usize> {
 992    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
 993        return Some(*signature_index);
 994    }
 995    let Some(parent_declaration) = index.declaration(declaration_id) else {
 996        log::error!("bug: missing parent declaration");
 997        return None;
 998    };
 999    let parent_index = parent_declaration.parent().and_then(|parent| {
1000        add_signature(parent, declaration_to_signature_index, signatures, index)
1001    });
1002    let (text, text_is_truncated) = parent_declaration.signature_text();
1003    let signature_index = signatures.len();
1004    signatures.push(Signature {
1005        text: text.into(),
1006        text_is_truncated,
1007        parent_index,
1008        range: parent_declaration.signature_line_range(),
1009    });
1010    declaration_to_signature_index.insert(declaration_id, signature_index);
1011    Some(signature_index)
1012}
1013
// Tests for `Zeta` driven through a fake HTTP client: each `/predict_edits/v3` request is
// forwarded to the test over a channel, and the test answers it through a oneshot sender.
#[cfg(test)]
mod tests {
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3::{self, Point};
    use edit_prediction_context::Line;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    // Checks what `current_prediction_for_buffer` reports for `buffer1`:
    // - `Local` when the predicted edit targets that buffer's own file;
    // - `Jump` when the edit targets another file (`2.txt`).
    // Once `2.txt` is opened, the same prediction is reported as `Local` for `buffer2`.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // Answer the outgoing request with a whole-file replacement of 1.txt.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // Answer with an edit to 2.txt, which has no open buffer yet.
        // NOTE(review): the range reuses snapshot1's row count; 2.txt also has 3 lines.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        // After opening 2.txt, the same prediction is local to that buffer.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    // Checks the request's excerpt path and cursor point. Also checks that a whole-file
    // replacement response yields a single edit that inserts " are you?" at the cursor (1, 3).
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // Checks that a buffer edit made after `register_buffer` appears in the request as a
    // single `BufferChange` event carrying a unified diff. The response handling is then
    // checked the same way as in `test_simple_request`.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Offset 7 is the start of the empty second line ("Hello!\n" is 7 bytes).
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                        @@ -1,3 +1,3 @@
                         Hello!
                        -
                        +How
                         Bye
                    "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // Checks that a diagnostic published through the LSP store before the buffer is opened
    // shows up in the request as one JSON-serialized diagnostic group. Only the request is
    // inspected; the response sender is dropped unanswered.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // Expected offsets 8..10: line 1 ("Bye") starts at offset 7, so column 1 is offset 8.
        // Presumably the out-of-range column 5 is clamped to the line end (offset 10).
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    // Sets up test settings, a client backed by a fake HTTP server, and a global `Zeta`.
    // The fake server:
    // - answers `/client/llm_tokens` with a fixed token;
    // - forwards each `/predict_edits/v3` request body, together with a oneshot sender, to
    //   the returned receiver, then responds with whatever the test sends back;
    // - panics on any other path.
    // If the test drops the oneshot sender, the `res_rx.await?` below makes that HTTP call fail.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}